Skip to content

Commit 7193454

Browse files
authored
feat: enhance WecomAIBotAdapter and WecomAIBotMessageEvent for improved streaming message handling (#5000)
fixes: #3965
1 parent d204b92 commit 7193454

2 files changed

Lines changed: 30 additions & 4 deletions

File tree

astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __init__(
127127
self.queue_mgr,
128128
self._handle_queued_message,
129129
)
130+
self._stream_plain_cache: dict[str, str] = {}
130131

131132
self.webhook_client: WecomAIBotWebhookClient | None = None
132133
if self.msg_push_webhook_url:
@@ -198,6 +199,7 @@ async def _process_message(
198199
# wechat server is requesting for updates of a stream
199200
stream_id = message_data["stream"]["id"]
200201
if not self.queue_mgr.has_back_queue(stream_id):
202+
self._stream_plain_cache.pop(stream_id, None)
201203
if self.queue_mgr.is_stream_finished(stream_id):
202204
logger.debug(
203205
f"Stream already finished, returning end message: {stream_id}"
@@ -225,24 +227,48 @@ async def _process_message(
225227
return None
226228

227229
# aggregate all delta chains in the back queue
228-
latest_plain_content = ""
230+
cached_plain_content = self._stream_plain_cache.get(stream_id, "")
231+
latest_plain_content = cached_plain_content
229232
image_base64 = []
230233
finish = False
231234
while not queue.empty():
232235
msg = await queue.get()
233236
if msg["type"] == "plain":
234-
latest_plain_content = msg["data"] or ""
237+
plain_data = msg.get("data") or ""
238+
if msg.get("streaming", False):
239+
# streaming plain payload is already cumulative
240+
cached_plain_content = plain_data
241+
else:
242+
# segmented non-stream send() pushes plain chunks, needs append
243+
cached_plain_content += plain_data
244+
latest_plain_content = cached_plain_content
235245
elif msg["type"] == "image":
236246
image_base64.append(msg["image_data"])
247+
elif msg["type"] == "break":
248+
continue
237249
elif msg["type"] in {"end", "complete"}:
238250
# stream end
239251
finish = True
240252
self.queue_mgr.remove_queues(stream_id, mark_finished=True)
253+
self._stream_plain_cache.pop(stream_id, None)
241254
break
242255

243256
logger.debug(
244257
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}",
245258
)
259+
if not finish:
260+
self._stream_plain_cache[stream_id] = cached_plain_content
261+
if finish and not latest_plain_content and not image_base64:
262+
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
263+
stream_id,
264+
"",
265+
True,
266+
)
267+
return await self.api_client.encrypt_message(
268+
end_message,
269+
callback_params["nonce"],
270+
callback_params["timestamp"],
271+
)
246272
if latest_plain_content or image_base64:
247273
msg_items = []
248274
if finish and image_base64:

astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def send_streaming(self, generator, use_fallback=False) -> None:
186186
"type": "break", # break means a segment end
187187
"data": final_data,
188188
"streaming": True,
189-
"session_id": self.session_id,
189+
"session_id": stream_id,
190190
},
191191
)
192192
final_data = ""
@@ -205,7 +205,7 @@ async def send_streaming(self, generator, use_fallback=False) -> None:
205205
"type": "complete", # complete means we return the final result
206206
"data": final_data,
207207
"streaming": True,
208-
"session_id": self.session_id,
208+
"session_id": stream_id,
209209
},
210210
)
211211
await super().send_streaming(generator, use_fallback)

0 commit comments

Comments
 (0)