30 changes: 28 additions & 2 deletions astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py
@@ -127,6 +127,7 @@ def __init__(
self.queue_mgr,
self._handle_queued_message,
)
self._stream_plain_cache: dict[str, str] = {}
Contributor

suggestion (performance): Consider lifecycle and potential unbounded growth of _stream_plain_cache.

Right now entries are only removed when a stream finishes and has_back_queue(stream_id) is false. If a stream never reaches completion and its back queue is never removed (e.g., producer crash or future logic changes), this dict can grow without bound in long‑running processes. Consider tying cache cleanup directly to queue lifecycle (e.g., in remove_queues or a shared teardown path), and/or adding a safeguard such as TTL, max size, or debug assertions to detect leaks.

Suggested implementation:

        )
        # Cache plain text for streaming responses; guarded by a max size and lifecycle-based cleanup
        self._stream_plain_cache: dict[str, str] = {}
        # Hard cap to avoid unbounded growth in long-running processes
        self._stream_plain_cache_max_size: int = 10_000

        self.webhook_client: WecomAIBotWebhookClient | None = None
        if self.msg_push_webhook_url:
            # Before handling the request, proactively clean up cache entries whose queues no longer exist.
            # This ties cache lifecycle to queue lifecycle and prevents silent leaks.
            if self._stream_plain_cache:
                stale_stream_ids: list[str] = []
                for cached_stream_id in list(self._stream_plain_cache.keys()):
                    if not self.queue_mgr.has_back_queue(cached_stream_id):
                        stale_stream_ids.append(cached_stream_id)

                for stale_stream_id in stale_stream_ids:
                    self._stream_plain_cache.pop(stale_stream_id, None)

                # Safeguard: enforce a hard cap on cache size.
                if len(self._stream_plain_cache) > self._stream_plain_cache_max_size:
                    logger.warning(
                        "wecom_ai_bot _stream_plain_cache exceeded max size (%d > %d); "
                        "clearing cached entries to stay within bounds",
                        len(self._stream_plain_cache),
                        self._stream_plain_cache_max_size,
                    )
                    # Drop arbitrary entries until under the limit.
                    # We use list(...) to avoid RuntimeError from dict size change during iteration.
                    for cached_stream_id in list(self._stream_plain_cache.keys()):
                        self._stream_plain_cache.pop(cached_stream_id, None)
                        if len(self._stream_plain_cache) <= self._stream_plain_cache_max_size:
                            break

            # wechat server is requesting for updates of a stream
            stream_id = message_data["stream"]["id"]
            if not self.queue_mgr.has_back_queue(stream_id):
                # If the queue for this stream no longer exists, clear any residual cache for it.
                self._stream_plain_cache.pop(stream_id, None)
                if self.queue_mgr.is_stream_finished(stream_id):
                    logger.debug(
                        "Stream already finished, returning end message: %s",
                        stream_id,
                    )
                    return None

To fully align cache lifecycle with queue lifecycle across the codebase, consider:

  1. In any method (possibly outside this file) that permanently tears down queues for a given stream_id (e.g., queue_mgr.remove_queues(stream_id) or similar), also removing the corresponding cache entry:
    self._stream_plain_cache.pop(stream_id, None)
  2. Ensuring all write sites for _stream_plain_cache reside in this class and use a consistent pattern so that future lifecycle changes (like a more sophisticated TTL or LRU strategy) can be centralized. If writes are scattered across the codebase, refactor them into a helper like _set_stream_plain_cache(stream_id, value) that can enforce limits and logging in one place.
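
If point 2 is adopted, a minimal sketch of the two helpers could look like the following. This assumes the module-level logger and the queue_mgr attribute already used in this adapter; the method names _set_stream_plain_cache and _clear_stream_state, the max-size constant, and the oldest-first eviction policy are illustrative, not existing code:

    # Hypothetical helpers inside the adapter class; names and eviction policy are assumptions.
    _STREAM_PLAIN_CACHE_MAX_SIZE = 10_000  # hard cap against unbounded growth

    def _set_stream_plain_cache(self, stream_id: str, value: str) -> None:
        """Single write path for the cache, so size limits and logging live in one place."""
        self._stream_plain_cache[stream_id] = value
        overflow = len(self._stream_plain_cache) - self._STREAM_PLAIN_CACHE_MAX_SIZE
        if overflow > 0:
            # Evict the oldest entries first; dicts preserve insertion order.
            for old_stream_id in list(self._stream_plain_cache)[:overflow]:
                self._stream_plain_cache.pop(old_stream_id, None)
            logger.warning(
                "wecom_ai_bot _stream_plain_cache hit max size; evicted %d entries",
                overflow,
            )

    def _clear_stream_state(self, stream_id: str) -> None:
        """Teardown path that keeps cache lifecycle tied to queue lifecycle."""
        self.queue_mgr.remove_queues(stream_id, mark_finished=True)
        self._stream_plain_cache.pop(stream_id, None)

The aggregation loop would then call self._set_stream_plain_cache(stream_id, cached_plain_content) instead of assigning to the dict directly, and the end/complete branch would call self._clear_stream_state(stream_id) in place of the separate remove_queues and pop calls.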


self.webhook_client: WecomAIBotWebhookClient | None = None
if self.msg_push_webhook_url:
@@ -198,6 +199,7 @@ async def _process_message(
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
if not self.queue_mgr.has_back_queue(stream_id):
self._stream_plain_cache.pop(stream_id, None)
if self.queue_mgr.is_stream_finished(stream_id):
logger.debug(
f"Stream already finished, returning end message: {stream_id}"
@@ -225,24 +227,48 @@ async def _process_message(
return None

# aggregate all delta chains in the back queue
latest_plain_content = ""
cached_plain_content = self._stream_plain_cache.get(stream_id, "")
latest_plain_content = cached_plain_content
image_base64 = []
finish = False
while not queue.empty():
msg = await queue.get()
if msg["type"] == "plain":
latest_plain_content = msg["data"] or ""
plain_data = msg.get("data") or ""
if msg.get("streaming", False):
# streaming plain payload is already cumulative
cached_plain_content = plain_data
else:
# segmented non-stream send() pushes plain chunks, needs append
cached_plain_content += plain_data
latest_plain_content = cached_plain_content
elif msg["type"] == "image":
image_base64.append(msg["image_data"])
elif msg["type"] == "break":
continue
elif msg["type"] in {"end", "complete"}:
# stream end
finish = True
self.queue_mgr.remove_queues(stream_id, mark_finished=True)
self._stream_plain_cache.pop(stream_id, None)
break

logger.debug(
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}",
)
if not finish:
self._stream_plain_cache[stream_id] = cached_plain_content
if finish and not latest_plain_content and not image_base64:
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id,
"",
True,
)
return await self.api_client.encrypt_message(
end_message,
callback_params["nonce"],
callback_params["timestamp"],
)
if latest_plain_content or image_base64:
msg_items = []
if finish and image_base64:
4 changes: 2 additions & 2 deletions astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py
@@ -186,7 +186,7 @@ async def send_streaming(self, generator, use_fallback=False) -> None:
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"session_id": self.session_id,
"session_id": stream_id,
},
)
final_data = ""
@@ -205,7 +205,7 @@ async def send_streaming(self, generator, use_fallback=False) -> None:
"type": "complete", # complete means we return the final result
"data": final_data,
"streaming": True,
"session_id": self.session_id,
"session_id": stream_id,
},
)
await super().send_streaming(generator, use_fallback)