Skip to content

Commit 9606c45

Browse files
authored
fix: Cherrypick PR#4782 to 0.7.1 -- have preprocessor populate "stop" field (#4858)
Signed-off-by: Qi Wang <[email protected]>
1 parent da6dc88 commit 9606c45

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

lib/llm/src/protocols/openai/completions.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,11 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
238238
}
239239

240240
fn get_stop(&self) -> Option<Vec<String>> {
241-
None
241+
use dynamo_async_openai::types::Stop;
242+
self.inner.stop.as_ref().map(|s| match s {
243+
Stop::String(s) => vec![s.clone()],
244+
Stop::StringArray(arr) => arr.clone(),
245+
})
242246
}
243247

244248
fn nvext(&self) -> Option<&NvExt> {
@@ -493,4 +497,36 @@ mod tests {
493497
assert_eq!(output_options.skip_special_tokens, Some(skip_value));
494498
}
495499
}
500+
501+
#[test]
502+
fn test_stop() {
503+
let null_stop = json!({
504+
"model": "test-model",
505+
"prompt": "Hello, world!"
506+
});
507+
let request: NvCreateCompletionRequest =
508+
serde_json::from_value(null_stop).expect("Failed to deserialize request");
509+
assert_eq!(request.get_stop(), None);
510+
511+
let one_stop = json!({
512+
"model": "test-model",
513+
"prompt": "Hello, world!",
514+
"stop": "foo"
515+
});
516+
let request: NvCreateCompletionRequest =
517+
serde_json::from_value(one_stop).expect("Failed to deserialize request");
518+
assert_eq!(request.get_stop(), Some(vec!["foo".to_string()]));
519+
520+
let many_stops = json!({
521+
"model": "test-model",
522+
"prompt": "Hello, world!",
523+
"stop": ["foo", "bar"]
524+
});
525+
let request: NvCreateCompletionRequest =
526+
serde_json::from_value(many_stops).expect("Failed to deserialize request");
527+
assert_eq!(
528+
request.get_stop(),
529+
Some(vec!["foo".to_string(), "bar".to_string()])
530+
);
531+
}
496532
}

0 commit comments

Comments
 (0)