Skip to content

Commit 49165d6

Browse files
authored
Refactor tool message formatting and enhance error handling in all_tools.py; update nest_asyncio dependency in pyproject.toml (#5408)
1 parent 260dda8 commit 49165d6

File tree

2 files changed

+38
-27
lines changed

2 files changed

+38
-27
lines changed

libs/chatchat-server/langchain_chatchat/agents/format_scratchpad/all_tools.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from langchain_core.agents import AgentAction
77
from langchain_core.messages import (
88
AIMessage,
9+
HumanMessage,
910
BaseMessage,
1011
ToolMessage,
1112
)
12-
1313
from langchain_chatchat.agent_toolkits import BaseToolOutput
1414
from langchain_chatchat.agent_toolkits.all_tools.code_interpreter_tool import (
1515
CodeInterpreterToolOutput,
@@ -53,7 +53,6 @@ def _create_tool_message(
5353
additional_kwargs={"name": agent_action.tool},
5454
)
5555

56-
5756
def format_to_platform_tool_messages(
5857
intermediate_steps: Sequence[Tuple[AgentAction, BaseToolOutput]],
5958
) -> List[BaseMessage]:
@@ -64,63 +63,74 @@ def format_to_platform_tool_messages(
6463
6564
Returns:
6665
list of messages to send to the LLM for the next prediction
67-
6866
"""
6967
messages = []
70-
for agent_action, observation in intermediate_steps:
68+
69+
for idx, (agent_action, observation) in enumerate(intermediate_steps):
70+
# === CodeInterpreter ===
7171
if isinstance(agent_action, CodeInterpreterAgentAction):
7272
if isinstance(observation, CodeInterpreterToolOutput):
73-
if "auto" == observation.platform_params.get("sandbox", "auto"):
73+
sandbox_type = observation.platform_params.get("sandbox", "auto")
74+
if sandbox_type == "auto":
7475
new_messages = [
7576
AIMessage(content=str(observation.code_input)),
7677
_create_tool_message(agent_action, observation),
7778
]
78-
79-
messages.extend(
80-
[new for new in new_messages if new not in messages]
81-
)
82-
elif "none" == observation.platform_params.get("sandbox", "auto"):
79+
elif sandbox_type == "none":
8380
new_messages = [
8481
AIMessage(content=str(observation.code_input)),
8582
_create_tool_message(agent_action, observation.code_output),
8683
]
87-
88-
messages.extend(
89-
[new for new in new_messages if new not in messages]
90-
)
9184
else:
92-
raise ValueError(
93-
f"Unknown sandbox type: {observation.platform_params.get('sandbox', 'auto')}"
94-
)
85+
raise ValueError(f"Unknown sandbox type: {sandbox_type}")
86+
messages.extend([m for m in new_messages if m not in messages])
9587
else:
9688
raise ValueError(f"Unknown observation type: {type(observation)}")
9789

90+
# === DrawingTool ===
9891
elif isinstance(agent_action, DrawingToolAgentAction):
9992
if isinstance(observation, DrawingToolOutput):
100-
new_messages = [AIMessage(content=str(observation))]
101-
messages.extend([new for new in new_messages if new not in messages])
93+
messages.append(AIMessage(content=str(observation)))
10294
else:
10395
raise ValueError(f"Unknown observation type: {type(observation)}")
10496

97+
# === WebBrowser ===
10598
elif isinstance(agent_action, WebBrowserAgentAction):
10699
if isinstance(observation, WebBrowserToolOutput):
107-
new_messages = [AIMessage(content=str(observation))]
108-
messages.extend([new for new in new_messages if new not in messages])
100+
messages.append(AIMessage(content=str(observation)))
109101
else:
110102
raise ValueError(f"Unknown observation type: {type(observation)}")
111103

104+
# === ToolAgentAction ===
112105
elif isinstance(agent_action, ToolAgentAction):
113106
ai_msgs = AIMessage(
114-
content=f"arguments='{agent_action.tool_input}', name='{agent_action.tool}'"
107+
content=f"arguments='{agent_action.tool_input}', name='{agent_action.tool}'",
108+
additional_kwargs={
109+
"tool_calls": [
110+
{
111+
"index": idx,
112+
"id": agent_action.tool_call_id,
113+
"type": "function",
114+
"function": {
115+
"name": agent_action.tool,
116+
"arguments": json.dumps(agent_action.tool_input, ensure_ascii=False),
117+
},
118+
}
119+
]
120+
},
115121
)
116-
new_messages = [ai_msgs, _create_tool_message(agent_action, observation)]
117-
messages.extend([new for new in new_messages if new not in messages])
122+
messages.extend([ai_msgs, _create_tool_message(agent_action, observation)])
123+
124+
# === Generic AgentAction ===
118125
elif isinstance(agent_action, AgentAction):
126+
# 这里假设 observation 是本项目自定义prompt tools,而不是 模型测tools
119127
ai_msgs = AIMessage(
120128
content=f"{agent_action.log}"
121129
)
122-
new_messages = [ai_msgs, _create_tool_message(agent_action, observation)]
123-
messages.extend([new for new in new_messages if new not in messages])
130+
messages.extend([ai_msgs, HumanMessage(content=str(observation))])
131+
132+
# === Fallback ===
124133
else:
125-
messages.append(AIMessage(content=agent_action.log))
134+
messages.append(AIMessage(content=getattr(agent_action, "log", str(agent_action))))
135+
126136
return messages

libs/chatchat-server/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ langchain-community = "0.0.36"
2020
langchain-openai = { version = "0.0.6", python = ">=3.8.1,<3.12,!=3.9.7" }
2121
langchain-experimental = "0.0.58"
2222
humanlayer= "0.7.6"
23+
nest_asyncio = "1.6.0"
2324
mcp = ">=1.4.1,<1.5"
2425
fastapi = "~0.109.2"
2526
sse_starlette = "~1.8.2"

0 commit comments

Comments
 (0)