Commit 4a418e2

committed
Simplify code
1 parent 72c4b89 commit 4a418e2

File tree

7 files changed: +273 -558 lines changed

api/v1/chat.py

Lines changed: 130 additions & 173 deletions
@@ -122,6 +122,126 @@ def verify_auth(authorization: str = Header(None)) -> bool:
     return True
 
 
+async def _run_engine_with_streaming(
+    engine,
+    request_id: str,
+    created: int,
+    model: str,
+    thinking_generator
+) -> AsyncIterator[str]:
+    """Run the engine and stream its output."""
+    progress_queue = []
+
+    def on_progress(event):
+        """Capture progress events."""
+        progress_queue.append(event)
+
+    # Set the progress callback
+    engine.on_progress = on_progress
+
+    # UltraThink-specific handling
+    if hasattr(engine, 'on_agent_update'):
+        def on_agent_update(agent_id: str, update: Dict[str, Any]):
+            """Capture agent updates."""
+            from models import ProgressEvent
+            progress_queue.append(ProgressEvent(
+                type="agent-update",
+                data={"agentId": agent_id, **update}
+            ))
+        engine.on_agent_update = on_agent_update
+
+    # Run the engine in the background
+    engine_task = asyncio.create_task(engine.run())
+
+    try:
+        # Stream progress events
+        while not engine_task.done():
+            # Process queued progress events
+            while progress_queue:
+                event = progress_queue.pop(0)
+                # If summary_think is enabled, convert the event into chain-of-thought text
+                if thinking_generator:
+                    thinking_text = thinking_generator.process_event(event)
+                    if thinking_text:
+                        # Emit the reasoning via the reasoning_content field
+                        delta = {"reasoning_content": thinking_text}
+                        chunk_data = {
+                            "id": request_id,
+                            "object": "chat.completion.chunk",
+                            "created": created,
+                            "model": model,
+                            "choices": [{
+                                "index": 0,
+                                "delta": delta,
+                                "finish_reason": None
+                            }]
+                        }
+                        yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
+            await asyncio.sleep(0.1)  # Brief pause to avoid a busy loop
+
+        # Fetch the final result
+        result = await engine_task
+
+        # Process any remaining progress events
+        while progress_queue:
+            event = progress_queue.pop(0)
+            if thinking_generator:
+                thinking_text = thinking_generator.process_event(event)
+                if thinking_text:
+                    delta = {"reasoning_content": thinking_text}
+                    chunk_data = {
+                        "id": request_id,
+                        "object": "chat.completion.chunk",
+                        "created": created,
+                        "model": model,
+                        "choices": [{
+                            "index": 0,
+                            "delta": delta,
+                            "finish_reason": None
+                        }]
+                    }
+                    yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
+
+        # Stream the final answer
+        final_text = result.summary or result.final_solution
+        for i in range(0, len(final_text), 50):
+            chunk = final_text[i:i+50]
+            delta = {"content": chunk}
+            chunk_data = {
+                "id": request_id,
+                "object": "chat.completion.chunk",
+                "created": created,
+                "model": model,
+                "choices": [{
+                    "index": 0,
+                    "delta": delta,
+                    "finish_reason": None
+                }]
+            }
+            yield f"data: {json.dumps(chunk_data)}\n\n"
+
+    except GeneratorExit:
+        # Client disconnected; cancel the engine task
+        logger.info(f"Client disconnected for request {request_id}, cancelling engine task")
+        if engine_task and not engine_task.done():
+            engine_task.cancel()
+            try:
+                await engine_task
+            except asyncio.CancelledError:
+                pass  # Expected cancellation
+        # Don't re-raise GeneratorExit; let the generator finish normally
+    except (asyncio.CancelledError, Exception) as e:
+        # Any other error: log it and cancel the task
+        logger.error(f"Error during streaming for request {request_id}: {e}")
+        if engine_task and not engine_task.done():
+            engine_task.cancel()
+            try:
+                await engine_task
+            except asyncio.CancelledError:
+                pass
+        raise  # Re-raise the exception
+
+
 async def stream_chat_completion(
     request: ChatCompletionRequest,
     model_config,
@@ -130,7 +250,6 @@ async def stream_chat_completion(
     """Streaming chat completion."""
     request_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
     created = int(time.time())
-    engine_task = None  # Track the engine task so it can be cancelled on disconnect
 
     # Extract LLM parameters
     llm_params = extract_llm_params(request)
@@ -178,205 +297,43 @@ async def stream_chat_completion(
     else:
         thinking_generator = ThinkingSummaryGenerator(mode="deepthink")
 
-    # Define the progress handler - converts progress events into streamed output
-    async def stream_progress(event):
-        """Process a progress event and stream it out."""
-        # If summary_think is enabled, convert the event into chain-of-thought text
-        if thinking_generator:
-            thinking_text = thinking_generator.process_event(event)
-            if thinking_text:
-                # Emit the reasoning via the reasoning_content field
-                delta = {"reasoning_content": thinking_text}
-                chunk_data = {
-                    "id": request_id,
-                    "object": "chat.completion.chunk",
-                    "created": created,
-                    "model": request.model,
-                    "choices": [{
-                        "index": 0,
-                        "delta": delta,
-                        "finish_reason": None
-                    }]
-                }
-                yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
 
     # Select the engine based on the model level
     if model_config.level == "ultrathink":
         # UltraThink mode
-        # Use a generator to capture progress and stream it out
-        progress_queue = []
-
-        def on_progress(event):
-            """Capture progress events."""
-            progress_queue.append(event)
-
-        def on_agent_update(agent_id: str, update: Dict[str, Any]):
-            """Capture agent updates."""
-            from models import ProgressEvent
-            progress_queue.append(ProgressEvent(
-                type="agent-update",
-                data={"agentId": agent_id, **update}
-            ))
-
-        # Run the engine - pass the structured conversation history and multimodal content
         engine = UltraThinkEngine(
             client=client,
             model=model_config.model,
-            problem_statement=problem_statement_raw,  # pass multimodal content
-            conversation_history=conversation_history,  # pass the structured message history
+            problem_statement=problem_statement_raw,
+            conversation_history=conversation_history,
             max_iterations=model_config.max_iterations,
             required_successful_verifications=model_config.required_verifications,
             num_agents=model_config.num_agent,
             parallel_run_agent=model_config.parallel_run_agent,
             model_stages=model_config.models,
-            on_progress=on_progress,
-            on_agent_update=on_agent_update,
             enable_parallel_check=model_config.parallel_check,
             llm_params=llm_params,
         )
-
-        # Run the engine in the background
-        engine_task = asyncio.create_task(engine.run())
-
-        try:
-            # Stream progress events
-            while not engine_task.done():
-                # Process queued progress events
-                while progress_queue:
-                    event = progress_queue.pop(0)
-                    async for chunk in stream_progress(event):
-                        yield chunk
-                await asyncio.sleep(0.1)  # Brief pause to avoid a busy loop
-
-            # Fetch the final result
-            result = await engine_task
-
-            # Process any remaining progress events
-            while progress_queue:
-                event = progress_queue.pop(0)
-                async for chunk in stream_progress(event):
-                    yield chunk
-
-            # Stream the final answer
-            final_text = result.summary or result.final_solution
-            for i in range(0, len(final_text), 50):
-                chunk = final_text[i:i+50]
-                delta = {"content": chunk}
-                chunk_data = {
-                    "id": request_id,
-                    "object": "chat.completion.chunk",
-                    "created": created,
-                    "model": request.model,
-                    "choices": [{
-                        "index": 0,
-                        "delta": delta,
-                        "finish_reason": None
-                    }]
-                }
-                yield f"data: {json.dumps(chunk_data)}\n\n"
-        except GeneratorExit:
-            # Client disconnected; cancel the engine task
-            logger.info(f"Client disconnected for request {request_id}, cancelling engine task")
-            if engine_task and not engine_task.done():
-                engine_task.cancel()
-                try:
-                    await engine_task
-                except asyncio.CancelledError:
-                    pass  # Expected cancellation
-            # Don't re-raise GeneratorExit; let the generator finish normally
-        except (asyncio.CancelledError, Exception) as e:
-            # Any other error: log it and cancel the task
-            logger.error(f"Error during streaming for request {request_id}: {e}")
-            if engine_task and not engine_task.done():
-                engine_task.cancel()
-                try:
-                    await engine_task
-                except asyncio.CancelledError:
-                    pass
-            raise  # Re-raise the exception
-
     else:  # deepthink
         # DeepThink mode
-        progress_queue = []
-
-        def on_progress(event):
-            """Capture progress events."""
-            progress_queue.append(event)
-
-        # Run the engine - pass the structured conversation history and multimodal content
         engine = DeepThinkEngine(
             client=client,
             model=model_config.model,
-            problem_statement=problem_statement_raw,  # pass multimodal content
-            conversation_history=conversation_history,  # pass the structured message history
+            problem_statement=problem_statement_raw,
+            conversation_history=conversation_history,
             max_iterations=model_config.max_iterations,
             required_successful_verifications=model_config.required_verifications,
             model_stages=model_config.models,
-            on_progress=on_progress,
             enable_planning=model_config.has_plan_mode,
             enable_parallel_check=model_config.parallel_check,
             llm_params=llm_params,
         )
-
-        # Run the engine in the background
-        engine_task = asyncio.create_task(engine.run())
-
-        try:
-            # Stream progress events
-            while not engine_task.done():
-                # Process queued progress events
-                while progress_queue:
-                    event = progress_queue.pop(0)
-                    async for chunk in stream_progress(event):
-                        yield chunk
-                await asyncio.sleep(0.1)  # Brief pause to avoid a busy loop
-
-            # Fetch the final result
-            result = await engine_task
-
-            # Process any remaining progress events
-            while progress_queue:
-                event = progress_queue.pop(0)
-                async for chunk in stream_progress(event):
-                    yield chunk
-
-            # Stream the final answer
-            final_text = result.summary or result.final_solution
-            for i in range(0, len(final_text), 50):
-                chunk = final_text[i:i+50]
-                delta = {"content": chunk}
-                chunk_data = {
-                    "id": request_id,
-                    "object": "chat.completion.chunk",
-                    "created": created,
-                    "model": request.model,
-                    "choices": [{
-                        "index": 0,
-                        "delta": delta,
-                        "finish_reason": None
-                    }]
-                }
-                yield f"data: {json.dumps(chunk_data)}\n\n"
-        except GeneratorExit:
-            # Client disconnected; cancel the engine task
-            logger.info(f"Client disconnected for request {request_id}, cancelling engine task")
-            if engine_task and not engine_task.done():
-                engine_task.cancel()
-                try:
-                    await engine_task
-                except asyncio.CancelledError:
-                    pass  # Expected cancellation
-            # Don't re-raise GeneratorExit; let the generator finish normally
-        except (asyncio.CancelledError, Exception) as e:
-            # Any other error: log it and cancel the task
-            logger.error(f"Error during streaming for request {request_id}: {e}")
-            if engine_task and not engine_task.done():
-                engine_task.cancel()
-                try:
-                    await engine_task
-                except asyncio.CancelledError:
-                    pass
-            raise  # Re-raise the exception
+
+    # Use the unified streaming helper
+    async for chunk in _run_engine_with_streaming(
+        engine, request_id, created, request.model, thinking_generator
+    ):
+        yield chunk
 
     # Send the end-of-stream marker
     chunk_data = {
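
Taken together, the diff above replaces two near-identical streaming loops with the single _run_engine_with_streaming helper. The core pattern it extracts is: run the engine as a background task, drain a progress queue while it runs, then chunk out the final answer. Below is a minimal, self-contained sketch of that pattern; FakeEngine, the event strings, and the chunk size are illustrative stand-ins, not part of this commit.

import asyncio
from typing import AsyncIterator, Callable, Optional


class FakeEngine:
    """Illustrative stand-in: reports progress via a callback, then returns a result."""

    def __init__(self) -> None:
        self.on_progress: Optional[Callable[[str], None]] = None

    async def run(self) -> str:
        for step in ("plan", "solve", "verify"):
            if self.on_progress:
                self.on_progress(f"step: {step}")  # fire a progress event
            await asyncio.sleep(0.05)
        return "final answer text " * 10


async def run_with_streaming(engine: FakeEngine) -> AsyncIterator[str]:
    queue: list[str] = []
    engine.on_progress = queue.append          # capture events as they arrive
    task = asyncio.create_task(engine.run())   # run the engine in the background
    try:
        while not task.done():
            while queue:                       # drain queued progress events
                yield f"progress: {queue.pop(0)}\n"
            await asyncio.sleep(0.1)           # brief pause to avoid a busy loop
        result = await task
        while queue:                           # drain anything left after completion
            yield f"progress: {queue.pop(0)}\n"
        for i in range(0, len(result), 50):    # stream the final answer in chunks
            yield f"content: {result[i:i + 50]}\n"
    except GeneratorExit:
        if not task.done():                    # consumer went away: stop the engine
            task.cancel()


async def main() -> None:
    async for chunk in run_with_streaming(FakeEngine()):
        print(chunk, end="")


asyncio.run(main())

Polling a plain list with asyncio.sleep(0.1) mirrors the commit's approach; an asyncio.Queue with await queue.get() would avoid the poll interval, but the list keeps the progress callback synchronous and trivially safe to call from engine code.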

api/v1/models.py

Lines changed: 1 addition & 16 deletions
@@ -6,26 +6,11 @@
 from typing import List, Dict, Any
 
 from config import config
+from .chat import verify_auth  # Reuse it; don't reinvent the wheel
 
 router = APIRouter()
 
 
-def verify_auth(authorization: str = Header(None)) -> bool:
-    """Verify the API key."""
-    if not config.api_key:
-        return True
-
-    if not authorization:
-        raise HTTPException(status_code=401, detail="Missing authorization header")
-
-    token = authorization.replace("Bearer ", "").strip()
-
-    if not config.validate_api_key(token):
-        raise HTTPException(status_code=401, detail="Invalid API key")
-
-    return True
-
-
 @router.get("/v1/models")
 async def list_models(authorization: str = Header(None)):
     """
