diff --git a/lib/llm/src/protocols/openai/completions.rs b/lib/llm/src/protocols/openai/completions.rs index b62f801c40..3971405339 100644 --- a/lib/llm/src/protocols/openai/completions.rs +++ b/lib/llm/src/protocols/openai/completions.rs @@ -238,7 +238,11 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { } fn get_stop(&self) -> Option> { - None + use dynamo_async_openai::types::Stop; + self.inner.stop.as_ref().map(|s| match s { + Stop::String(s) => vec![s.clone()], + Stop::StringArray(arr) => arr.clone(), + }) } fn nvext(&self) -> Option<&NvExt> { @@ -494,4 +498,33 @@ mod tests { assert_eq!(output_options.skip_special_tokens, Some(skip_value)); } } + + #[test] + fn test_stop() { + let null_stop = json!({ + "model": "test-model", + "prompt": "Hello, world!" + }); + let request: NvCreateCompletionRequest = + serde_json::from_value(null_stop).expect("Failed to deserialize request"); + assert_eq!(request.get_stop(), None); + + let one_stop = json!({ + "model": "test-model", + "prompt": "Hello, world!", + "stop": "foo" + }); + let request: NvCreateCompletionRequest = + serde_json::from_value(one_stop).expect("Failed to deserialize request"); + assert_eq!(request.get_stop(), Some(vec!["foo".to_string()])); + + let many_stops = json!({ + "model": "test-model", + "prompt": "Hello, world!", + "stop": ["foo", "bar"] + }); + let request: NvCreateCompletionRequest = + serde_json::from_value(many_stops).expect("Failed to deserialize request"); + assert_eq!(request.get_stop(), Some(vec!["foo".to_string(), "bar".to_string()])); + } }