Skip to content

Commit ce602cb

Browse files
ayushag-nv2ez4bz
andauthored
refactor: New config types for tool calls (#4575) (#4857)
Signed-off-by: ayushag <[email protected]> Co-authored-by: William Zhang <[email protected]>
1 parent 13aa491 commit ce602cb

File tree

8 files changed

+514
-170
lines changed

8 files changed

+514
-170
lines changed

lib/llm/src/protocols/openai/chat_completions/jail.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -887,14 +887,14 @@ impl JailedStreamBuilder {
887887
if let Some(config) = parser_map.get(parser_name.as_str()) {
888888
// Auto-populate start sequences if none configured
889889
if self.jail_start_sequences.is_empty() {
890-
self.jail_start_sequences = config.json.tool_call_start_tokens.clone();
890+
self.jail_start_sequences = config.parser_config.tool_call_start_tokens();
891891
}
892892

893893
// Auto-populate end sequences if none configured
894894
if self.jail_end_sequences.is_empty() {
895895
self.jail_end_sequences = config
896-
.json
897-
.tool_call_end_tokens
896+
.parser_config
897+
.tool_call_end_tokens()
898898
.iter()
899899
.filter(|&s| !s.is_empty())
900900
.cloned()
@@ -914,7 +914,7 @@ impl JailedStreamBuilder {
914914
let parser_map = get_tool_parser_map();
915915
if let Some(config) = parser_map.get(parser_name.as_str()) {
916916
// Add start tokens from the parser config
917-
all_patterns.extend(config.json.tool_call_start_tokens.clone());
917+
all_patterns.extend(config.parser_config.tool_call_start_tokens());
918918
}
919919
}
920920

lib/llm/tests/test_jail.rs

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,6 +1935,239 @@ mod tests {
19351935
}
19361936
}
19371937

1938+
#[tokio::test]
1939+
async fn test_jailed_stream_qwen3_coder_parser() {
1940+
// Input:
1941+
// "I'll call a function. "
1942+
// + "<tool_call><function=get_weather><parameter=location>San Francisco</parameter><parameter=unit>celsius</parameter></function></tool_call>"
1943+
// + " Done."
1944+
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
1945+
let chunks = vec![
1946+
create_mock_response_chunk("I'll call a function. ".to_string(), 0),
1947+
create_mock_response_chunk("<tool_call>".to_string(), 0),
1948+
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
1949+
create_mock_response_chunk(
1950+
"<parameter=location>San Francisco</parameter>".to_string(),
1951+
0,
1952+
),
1953+
create_mock_response_chunk("<parameter=unit>celsius</parameter>".to_string(), 0),
1954+
create_mock_response_chunk("</function>".to_string(), 0),
1955+
create_mock_response_chunk("</tool_call>".to_string(), 0),
1956+
create_mock_response_chunk(" Done.".to_string(), 0),
1957+
];
1958+
1959+
let input_stream = stream::iter(chunks);
1960+
1961+
let jail = JailedStream::builder()
1962+
.tool_call_parser("qwen3_coder")
1963+
.build();
1964+
1965+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
1966+
1967+
assert_eq!(
1968+
results.len(),
1969+
3,
1970+
"Should have content, tool call, and trailing content"
1971+
);
1972+
1973+
// Verify exact output structure: [Content(), ToolCall(), Content()].
1974+
test_utils::assert_content(&results[0], "I'll call a function. ");
1975+
test_utils::assert_tool_call(
1976+
&results[1],
1977+
"get_weather",
1978+
serde_json::json!({"location": "San Francisco", "unit": "celsius"}),
1979+
);
1980+
test_utils::assert_content(&results[2], " Done.");
1981+
1982+
// Verify content reconstruction excludes tool calls.
1983+
let reconstructed = test_utils::reconstruct_content(&results);
1984+
assert_eq!(reconstructed, "I'll call a function. Done.");
1985+
}
1986+
1987+
#[tokio::test]
1988+
async fn test_jailed_stream_qwen3_coder_multiple_params() {
1989+
let chunks = vec![
1990+
create_mock_response_chunk("Let me search for that. ".to_string(), 0),
1991+
create_mock_response_chunk(
1992+
"<tool_call><function=web_search><parameter=query>Rust programming</parameter><parameter=max_results>10</parameter><parameter=filter>recent</parameter></function></tool_call>".to_string(),
1993+
0,
1994+
),
1995+
create_mock_response_chunk(" Searching now.".to_string(), 0),
1996+
];
1997+
1998+
let input_stream = stream::iter(chunks);
1999+
let jail = JailedStream::builder()
2000+
.tool_call_parser("qwen3_coder")
2001+
.build();
2002+
2003+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
2004+
2005+
assert_eq!(results.len(), 3, "Should have 3 chunks");
2006+
2007+
test_utils::assert_content(&results[0], "Let me search for that. ");
2008+
test_utils::assert_tool_call(
2009+
&results[1],
2010+
"web_search",
2011+
serde_json::json!({
2012+
"query": "Rust programming",
2013+
"max_results": 10,
2014+
"filter": "recent"
2015+
}),
2016+
);
2017+
test_utils::assert_content(&results[2], " Searching now.");
2018+
}
2019+
2020+
#[tokio::test]
2021+
async fn test_jailed_stream_xml_parser_config_tokens_auto_population() {
2022+
// Tests that parser config tokens are auto-populated when using `.tool_call_parser()`.
2023+
// This verifies the jail system reads `tool_call_start_token` and `tool_call_end_token`
2024+
// from the `qwen3_coder` parser config.
2025+
let chunks = vec![
2026+
create_mock_response_chunk("Before tool call. ".to_string(), 0),
2027+
create_mock_response_chunk("<tool_call>".to_string(), 0), // Default qwen3_coder token
2028+
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
2029+
create_mock_response_chunk("<parameter=city>Seattle</parameter>".to_string(), 0),
2030+
create_mock_response_chunk("</function>".to_string(), 0),
2031+
create_mock_response_chunk("</tool_call>".to_string(), 0), // Default qwen3_coder token
2032+
create_mock_response_chunk(" After tool call.".to_string(), 0),
2033+
];
2034+
2035+
let input_stream = stream::iter(chunks);
2036+
2037+
// Create JailedStream using ONLY `.tool_call_parser()`.
2038+
// This should auto-populate jail sequences from the qwen3_coder config
2039+
let jail = JailedStream::builder()
2040+
.tool_call_parser("qwen3_coder")
2041+
.build();
2042+
2043+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
2044+
2045+
assert_eq!(
2046+
results.len(),
2047+
3,
2048+
"Should have content, tool call, and trailing content"
2049+
);
2050+
2051+
test_utils::assert_content(&results[0], "Before tool call. ");
2052+
test_utils::assert_tool_call(
2053+
&results[1],
2054+
"get_weather",
2055+
serde_json::json!({"city": "Seattle"}),
2056+
);
2057+
test_utils::assert_content(&results[2], " After tool call.");
2058+
2059+
let reconstructed = test_utils::reconstruct_content(&results);
2060+
assert_eq!(reconstructed, "Before tool call. After tool call.");
2061+
}
2062+
2063+
#[tokio::test]
2064+
async fn test_jailed_stream_xml_manual_sequences_prevent_auto_population() {
2065+
// Tests that manually setting jail sequences prevents auto-population.
2066+
// This verifies the builder respects manual configuration over auto-population.
2067+
//
2068+
// When custom sequences are set, the default parser tokens (<tool_call>) should
2069+
// NOT trigger jailing and should pass through as regular content.
2070+
let chunks = vec![
2071+
create_mock_response_chunk("Text with ".to_string(), 0),
2072+
// Default qwen3_coder token - should NOT trigger jailing.
2073+
create_mock_response_chunk("<tool_call>".to_string(), 0),
2074+
create_mock_response_chunk("should not jail".to_string(), 0),
2075+
create_mock_response_chunk("</tool_call>".to_string(), 0),
2076+
create_mock_response_chunk(" because custom ".to_string(), 0),
2077+
// Custom marker - this SHOULD trigger jailing since we register it below.
2078+
create_mock_response_chunk("[[START]]".to_string(), 0),
2079+
create_mock_response_chunk("jailed content".to_string(), 0),
2080+
create_mock_response_chunk("[[END]]".to_string(), 0),
2081+
create_mock_response_chunk(" text.".to_string(), 0),
2082+
];
2083+
2084+
let input_stream = stream::iter(chunks);
2085+
2086+
// Set custom jail sequences - this should prevent auto-population.
2087+
// The default <tool_call> tokens should NOT trigger jailing.
2088+
let jail = JailedStream::builder()
2089+
.jail_start_sequence("[[START]]")
2090+
.jail_end_sequence("[[END]]")
2091+
.tool_call_parser("qwen3_coder")
2092+
.build();
2093+
2094+
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
2095+
2096+
// The exact number of chunks depends on emission mode (packed vs single-choice-per-chunk)
2097+
// but we can verify the key behaviors:
2098+
// 1. Default <tool_call> tokens pass through as content (not jailed)
2099+
// 2. Custom [[START]]/[[END]] markers trigger jailing
2100+
// 3. No tool calls are extracted (because jailed content isn't valid XML)
2101+
2102+
// Find chunk(s) containing the default tokens that passed through.
2103+
let default_token_chunks: Vec<_> = results
2104+
.iter()
2105+
.filter_map(|r| {
2106+
r.data
2107+
.as_ref()
2108+
.and_then(|d| d.choices.first())
2109+
.and_then(|c| c.delta.content.as_ref())
2110+
})
2111+
.filter(|content| {
2112+
content.contains("<tool_call>") || content.contains("should not jail")
2113+
})
2114+
.collect();
2115+
2116+
assert!(
2117+
!default_token_chunks.is_empty(),
2118+
"Default <tool_call> should pass through as content when manual sequences are set"
2119+
);
2120+
2121+
// Find chunk containing the jailed content that was released.
2122+
let jailed_chunk = results
2123+
.iter()
2124+
.filter_map(|r| {
2125+
r.data
2126+
.as_ref()
2127+
.and_then(|d| d.choices.first())
2128+
.and_then(|c| c.delta.content.as_ref())
2129+
})
2130+
.find(|content| content.contains("[[START]]") && content.contains("jailed content"));
2131+
2132+
assert!(
2133+
jailed_chunk.is_some(),
2134+
"Custom markers should trigger jailing and accumulated content should be released"
2135+
);
2136+
2137+
// Since the custom markers include non-XML content, the parser should not extract tool calls.
2138+
// The accumulated content "[[START]]jailed content[[END]]", although compatible with the
2139+
// way we configured `jail` above, is not consistent with what `qwen_coder` expects, and
2140+
// there is (at time of writing) no way to pass a parser instance - only a string that
2141+
// internally gets mapped to default way of instantiating a particular parser.
2142+
let tool_call_count = results
2143+
.iter()
2144+
.filter(|r| {
2145+
r.data
2146+
.as_ref()
2147+
.and_then(|d| d.choices.first())
2148+
.and_then(|c| c.delta.tool_calls.as_ref())
2149+
.map(|tc| !tc.is_empty())
2150+
.unwrap_or(false)
2151+
})
2152+
.count();
2153+
2154+
assert_eq!(
2155+
tool_call_count, 0,
2156+
"Should have 0 tool calls because jailed content doesn't match XML format"
2157+
);
2158+
2159+
// Verify content reconstruction - all original content should be preserved.
2160+
let reconstructed = test_utils::reconstruct_content(&results);
2161+
assert!(
2162+
reconstructed.contains("<tool_call>") && reconstructed.contains("should not jail"),
2163+
"Reconstructed content should include default tokens that passed through"
2164+
);
2165+
assert!(
2166+
reconstructed.contains("[[START]]") && reconstructed.contains("jailed content"),
2167+
"Reconstructed content should include jailed content with custom markers"
2168+
);
2169+
}
2170+
19382171
#[tokio::test]
19392172
async fn test_jailed_stream_mistral_false_positive_curly() {
19402173
// Curly brace in normal text should not trigger tool call detection for mistral

0 commit comments

Comments
 (0)