Skip to content

Commit 44e8600

Browse files
authored
refactor: New config types for tool calls (#4575)
* Why? We would like the ability to configure different parser types. Prior to this commit, only the JSON parser could be configured. * What? This commit refactors the tool parser config in the following ways: - the `format` and `json` fields of `ToolParserConfig` are merged into a single `parser_config` field that is a "discriminated union" type, so each parser type can declare its own configuration options. - an `XmlParserConfig` is defined with a default factory method that corresponds to the Qwen3 coder configuration. - affected call sites and tests are adjusted.
1 parent 262cce7 commit 44e8600

File tree

8 files changed

+514
-170
lines changed

8 files changed

+514
-170
lines changed

lib/llm/src/protocols/openai/chat_completions/jail.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -895,14 +895,14 @@ impl JailedStreamBuilder {
895895
if let Some(config) = parser_map.get(parser_name.as_str()) {
896896
// Auto-populate start sequences if none configured
897897
if self.jail_start_sequences.is_empty() {
898-
self.jail_start_sequences = config.json.tool_call_start_tokens.clone();
898+
self.jail_start_sequences = config.parser_config.tool_call_start_tokens();
899899
}
900900

901901
// Auto-populate end sequences if none configured
902902
if self.jail_end_sequences.is_empty() {
903903
self.jail_end_sequences = config
904-
.json
905-
.tool_call_end_tokens
904+
.parser_config
905+
.tool_call_end_tokens()
906906
.iter()
907907
.filter(|&s| !s.is_empty())
908908
.cloned()
@@ -922,7 +922,7 @@ impl JailedStreamBuilder {
922922
let parser_map = get_tool_parser_map();
923923
if let Some(config) = parser_map.get(parser_name.as_str()) {
924924
// Add start tokens from the parser config
925-
all_patterns.extend(config.json.tool_call_start_tokens.clone());
925+
all_patterns.extend(config.parser_config.tool_call_start_tokens());
926926
}
927927
}
928928

lib/llm/tests/test_jail.rs

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,6 +1989,239 @@ mod tests {
19891989
}
19901990
}
19911991

1992+
#[tokio::test]
1993+
async fn test_jailed_stream_qwen3_coder_parser() {
1994+
// Input:
1995+
// "I'll call a function. "
1996+
// + "<tool_call><function=get_weather><parameter=location>San Francisco</parameter><parameter=unit>celsius</parameter></function></tool_call>"
1997+
// + " Done."
1998+
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
1999+
let chunks = vec![
2000+
create_mock_response_chunk("I'll call a function. ".to_string(), 0),
2001+
create_mock_response_chunk("<tool_call>".to_string(), 0),
2002+
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
2003+
create_mock_response_chunk(
2004+
"<parameter=location>San Francisco</parameter>".to_string(),
2005+
0,
2006+
),
2007+
create_mock_response_chunk("<parameter=unit>celsius</parameter>".to_string(), 0),
2008+
create_mock_response_chunk("</function>".to_string(), 0),
2009+
create_mock_response_chunk("</tool_call>".to_string(), 0),
2010+
create_mock_response_chunk(" Done.".to_string(), 0),
2011+
];
2012+
2013+
let input_stream = stream::iter(chunks);
2014+
2015+
let jail = JailedStream::builder()
2016+
.tool_call_parser("qwen3_coder")
2017+
.build();
2018+
2019+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
2020+
2021+
assert_eq!(
2022+
results.len(),
2023+
3,
2024+
"Should have content, tool call, and trailing content"
2025+
);
2026+
2027+
// Verify exact output structure: [Content(), ToolCall(), Content()].
2028+
test_utils::assert_content(&results[0], "I'll call a function. ");
2029+
test_utils::assert_tool_call(
2030+
&results[1],
2031+
"get_weather",
2032+
serde_json::json!({"location": "San Francisco", "unit": "celsius"}),
2033+
);
2034+
test_utils::assert_content(&results[2], " Done.");
2035+
2036+
// Verify content reconstruction excludes tool calls.
2037+
let reconstructed = test_utils::reconstruct_content(&results);
2038+
assert_eq!(reconstructed, "I'll call a function. Done.");
2039+
}
2040+
2041+
#[tokio::test]
2042+
async fn test_jailed_stream_qwen3_coder_multiple_params() {
2043+
let chunks = vec![
2044+
create_mock_response_chunk("Let me search for that. ".to_string(), 0),
2045+
create_mock_response_chunk(
2046+
"<tool_call><function=web_search><parameter=query>Rust programming</parameter><parameter=max_results>10</parameter><parameter=filter>recent</parameter></function></tool_call>".to_string(),
2047+
0,
2048+
),
2049+
create_mock_response_chunk(" Searching now.".to_string(), 0),
2050+
];
2051+
2052+
let input_stream = stream::iter(chunks);
2053+
let jail = JailedStream::builder()
2054+
.tool_call_parser("qwen3_coder")
2055+
.build();
2056+
2057+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
2058+
2059+
assert_eq!(results.len(), 3, "Should have 3 chunks");
2060+
2061+
test_utils::assert_content(&results[0], "Let me search for that. ");
2062+
test_utils::assert_tool_call(
2063+
&results[1],
2064+
"web_search",
2065+
serde_json::json!({
2066+
"query": "Rust programming",
2067+
"max_results": 10,
2068+
"filter": "recent"
2069+
}),
2070+
);
2071+
test_utils::assert_content(&results[2], " Searching now.");
2072+
}
2073+
2074+
#[tokio::test]
2075+
async fn test_jailed_stream_xml_parser_config_tokens_auto_population() {
2076+
// Tests that parser config tokens are auto-populated when using `.tool_call_parser()`.
2077+
// This verifies the jail system reads `tool_call_start_token` and `tool_call_end_token`
2078+
// from the `qwen3_coder` parser config.
2079+
let chunks = vec![
2080+
create_mock_response_chunk("Before tool call. ".to_string(), 0),
2081+
create_mock_response_chunk("<tool_call>".to_string(), 0), // Default qwen3_coder token
2082+
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
2083+
create_mock_response_chunk("<parameter=city>Seattle</parameter>".to_string(), 0),
2084+
create_mock_response_chunk("</function>".to_string(), 0),
2085+
create_mock_response_chunk("</tool_call>".to_string(), 0), // Default qwen3_coder token
2086+
create_mock_response_chunk(" After tool call.".to_string(), 0),
2087+
];
2088+
2089+
let input_stream = stream::iter(chunks);
2090+
2091+
// Create JailedStream using ONLY `.tool_call_parser()`.
2092+
// This should auto-populate jail sequences from the qwen3_coder config
2093+
let jail = JailedStream::builder()
2094+
.tool_call_parser("qwen3_coder")
2095+
.build();
2096+
2097+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
2098+
2099+
assert_eq!(
2100+
results.len(),
2101+
3,
2102+
"Should have content, tool call, and trailing content"
2103+
);
2104+
2105+
test_utils::assert_content(&results[0], "Before tool call. ");
2106+
test_utils::assert_tool_call(
2107+
&results[1],
2108+
"get_weather",
2109+
serde_json::json!({"city": "Seattle"}),
2110+
);
2111+
test_utils::assert_content(&results[2], " After tool call.");
2112+
2113+
let reconstructed = test_utils::reconstruct_content(&results);
2114+
assert_eq!(reconstructed, "Before tool call. After tool call.");
2115+
}
2116+
2117+
#[tokio::test]
2118+
async fn test_jailed_stream_xml_manual_sequences_prevent_auto_population() {
2119+
// Tests that manually setting jail sequences prevents auto-population.
2120+
// This verifies the builder respects manual configuration over auto-population.
2121+
//
2122+
// When custom sequences are set, the default parser tokens (<tool_call>) should
2123+
// NOT trigger jailing and should pass through as regular content.
2124+
let chunks = vec![
2125+
create_mock_response_chunk("Text with ".to_string(), 0),
2126+
// Default qwen3_coder token - should NOT trigger jailing.
2127+
create_mock_response_chunk("<tool_call>".to_string(), 0),
2128+
create_mock_response_chunk("should not jail".to_string(), 0),
2129+
create_mock_response_chunk("</tool_call>".to_string(), 0),
2130+
create_mock_response_chunk(" because custom ".to_string(), 0),
2131+
// Custom marker - this SHOULD trigger jailing since we register it below.
2132+
create_mock_response_chunk("[[START]]".to_string(), 0),
2133+
create_mock_response_chunk("jailed content".to_string(), 0),
2134+
create_mock_response_chunk("[[END]]".to_string(), 0),
2135+
create_mock_response_chunk(" text.".to_string(), 0),
2136+
];
2137+
2138+
let input_stream = stream::iter(chunks);
2139+
2140+
// Set custom jail sequences - this should prevent auto-population.
2141+
// The default <tool_call> tokens should NOT trigger jailing.
2142+
let jail = JailedStream::builder()
2143+
.jail_start_sequence("[[START]]")
2144+
.jail_end_sequence("[[END]]")
2145+
.tool_call_parser("qwen3_coder")
2146+
.build();
2147+
2148+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
2149+
2150+
// The exact number of chunks depends on emission mode (packed vs single-choice-per-chunk)
2151+
// but we can verify the key behaviors:
2152+
// 1. Default <tool_call> tokens pass through as content (not jailed)
2153+
// 2. Custom [[START]]/[[END]] markers trigger jailing
2154+
// 3. No tool calls are extracted (because jailed content isn't valid XML)
2155+
2156+
// Find chunk(s) containing the default tokens that passed through.
2157+
let default_token_chunks: Vec<_> = results
2158+
.iter()
2159+
.filter_map(|r| {
2160+
r.data
2161+
.as_ref()
2162+
.and_then(|d| d.choices.first())
2163+
.and_then(|c| c.delta.content.as_ref())
2164+
})
2165+
.filter(|content| {
2166+
content.contains("<tool_call>") || content.contains("should not jail")
2167+
})
2168+
.collect();
2169+
2170+
assert!(
2171+
!default_token_chunks.is_empty(),
2172+
"Default <tool_call> should pass through as content when manual sequences are set"
2173+
);
2174+
2175+
// Find chunk containing the jailed content that was released.
2176+
let jailed_chunk = results
2177+
.iter()
2178+
.filter_map(|r| {
2179+
r.data
2180+
.as_ref()
2181+
.and_then(|d| d.choices.first())
2182+
.and_then(|c| c.delta.content.as_ref())
2183+
})
2184+
.find(|content| content.contains("[[START]]") && content.contains("jailed content"));
2185+
2186+
assert!(
2187+
jailed_chunk.is_some(),
2188+
"Custom markers should trigger jailing and accumulated content should be released"
2189+
);
2190+
2191+
// Since the custom markers include non-XML content, the parser should not extract tool calls.
2192+
// The accumulated content "[[START]]jailed content[[END]]", although compatible with the
2193+
// way we configured `jail` above, is not consistent with what `qwen3_coder` expects, and
2194+
// there is (at time of writing) no way to pass a parser instance - only a string that
2195+
// internally gets mapped to the default way of instantiating a particular parser.
2196+
let tool_call_count = results
2197+
.iter()
2198+
.filter(|r| {
2199+
r.data
2200+
.as_ref()
2201+
.and_then(|d| d.choices.first())
2202+
.and_then(|c| c.delta.tool_calls.as_ref())
2203+
.map(|tc| !tc.is_empty())
2204+
.unwrap_or(false)
2205+
})
2206+
.count();
2207+
2208+
assert_eq!(
2209+
tool_call_count, 0,
2210+
"Should have 0 tool calls because jailed content doesn't match XML format"
2211+
);
2212+
2213+
// Verify content reconstruction - all original content should be preserved.
2214+
let reconstructed = test_utils::reconstruct_content(&results);
2215+
assert!(
2216+
reconstructed.contains("<tool_call>") && reconstructed.contains("should not jail"),
2217+
"Reconstructed content should include default tokens that passed through"
2218+
);
2219+
assert!(
2220+
reconstructed.contains("[[START]]") && reconstructed.contains("jailed content"),
2221+
"Reconstructed content should include jailed content with custom markers"
2222+
);
2223+
}
2224+
19922225
#[tokio::test]
19932226
async fn test_jailed_stream_mistral_false_positive_curly() {
19942227
// Curly brace in normal text should not trigger tool call detection for mistral

0 commit comments

Comments
 (0)