Skip to content

Commit 773834b

Browse files
committed
format, collapse the if branches
Signed-off-by: [email protected] <[email protected]>
1 parent 89ef073 commit 773834b

File tree

1 file changed

+51
-52
lines changed

1 file changed

+51
-52
lines changed

lib/llm/src/http/service/openai.rs

Lines changed: 51 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -735,14 +735,16 @@ fn extract_backend_error_if_present<T: serde::Serialize>(
735735
if let Some(event_type) = &event.event
736736
&& event_type == "error"
737737
{
738-
let comment_str = event.comment
738+
let comment_str = event
739+
.comment
739740
.as_ref()
740741
.map(|c| c.join(", "))
741742
.unwrap_or_else(|| "Unknown error".to_string());
742743

743744
// Try to parse comment as error JSON to extract status code
744745
if let Ok(error_payload) = serde_json::from_str::<ErrorPayload>(&comment_str) {
745-
let code = error_payload.code
746+
let code = error_payload
747+
.code
746748
.and_then(|c| StatusCode::from_u16(c).ok())
747749
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
748750
let message = error_payload.message.unwrap_or(comment_str);
@@ -753,21 +755,17 @@ fn extract_backend_error_if_present<T: serde::Serialize>(
753755
}
754756

755757
// Check if the data payload itself contains an error structure with code >= 400
756-
if let Some(data) = &event.data {
757-
if let Ok(json_value) = serde_json::to_value(data) {
758-
if let Ok(error_payload) = serde_json::from_value::<ErrorPayload>(json_value.clone()) {
759-
if let Some(code_num) = error_payload.code {
760-
if code_num >= 400 {
761-
let code = StatusCode::from_u16(code_num)
762-
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
763-
let message = error_payload.message.unwrap_or_else(|| {
764-
json_value.to_string()
765-
});
766-
return Some((message, code));
767-
}
768-
}
769-
}
770-
}
758+
if let Some(data) = &event.data
759+
&& let Ok(json_value) = serde_json::to_value(data)
760+
&& let Ok(error_payload) = serde_json::from_value::<ErrorPayload>(json_value.clone())
761+
&& let Some(code_num) = error_payload.code
762+
&& code_num >= 400
763+
{
764+
let code = StatusCode::from_u16(code_num).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
765+
let message = error_payload
766+
.message
767+
.unwrap_or_else(|| json_value.to_string());
768+
return Some((message, code));
771769
}
772770

773771
// Check if comment contains error information (without event: error)
@@ -777,19 +775,18 @@ fn extract_backend_error_if_present<T: serde::Serialize>(
777775
let comment_str = comments.join(", ");
778776

779777
// Try to parse comment as error JSON with code >= 400
780-
if let Ok(error_payload) = serde_json::from_str::<ErrorPayload>(&comment_str) {
781-
if let Some(code_num) = error_payload.code {
782-
if code_num >= 400 {
783-
let code = StatusCode::from_u16(code_num)
784-
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
785-
let message = error_payload.message.unwrap_or(comment_str);
786-
return Some((message, code));
787-
}
788-
}
778+
if let Ok(error_payload) = serde_json::from_str::<ErrorPayload>(&comment_str)
779+
&& let Some(code_num) = error_payload.code
780+
&& code_num >= 400
781+
{
782+
let code = StatusCode::from_u16(code_num).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
783+
let message = error_payload.message.unwrap_or(comment_str);
784+
return Some((message, code));
789785
}
790786

791-
// Comments present with no data often indicates error
792-
if event.data.is_none() {
787+
// Comments present with no data AND no event type indicates error
788+
// (events with event types like "request_id" or "event.dynamo.test.sentinel" are annotations)
789+
if event.data.is_none() && event.event.is_none() {
793790
return Some((comment_str, StatusCode::INTERNAL_SERVER_ERROR));
794791
}
795792
}
@@ -937,17 +934,11 @@ async fn chat_completions(
937934
// todo - tap the stream and propagate request level metrics
938935
// note - we might do this as part of the post processing set to make it more generic
939936

940-
// Check first event for backend errors before streaming
941-
let stream_with_check = check_for_backend_error(stream).await.map_err(|error_response| {
942-
tracing::error!(request_id, "Backend error detected: {:?}", error_response);
943-
error_response
944-
})?;
945-
946937
if streaming {
947938
stream_handle.arm(); // allows the system to detect client disconnects and cancel the LLM generation
948939

949940
let mut http_queue_guard = Some(http_queue_guard);
950-
let stream = stream_with_check.map(move |response| {
941+
let stream = stream.map(move |response| {
951942
// Calls observe_response() on each token
952943
process_response_using_event_converter_and_observe_metrics(
953944
EventConverter::from(response),
@@ -965,6 +956,15 @@ async fn chat_completions(
965956

966957
Ok(sse_stream.into_response())
967958
} else {
959+
// Check first event for backend errors before aggregating (non-streaming only)
960+
let stream_with_check =
961+
check_for_backend_error(stream)
962+
.await
963+
.map_err(|error_response| {
964+
tracing::error!(request_id, "Backend error detected: {:?}", error_response);
965+
error_response
966+
})?;
967+
968968
let mut http_queue_guard = Some(http_queue_guard);
969969
let stream = stream_with_check.inspect(move |response| {
970970
// Calls observe_response() on each token - drops http_queue_guard on first token
@@ -975,22 +975,20 @@ async fn chat_completions(
975975
);
976976
});
977977

978-
let response = NvCreateChatCompletionResponse::from_annotated_stream(
979-
stream,
980-
parsing_options.clone(),
981-
)
982-
.await
983-
.map_err(|e| {
984-
tracing::error!(
985-
request_id,
986-
"Failed to parse chat completion response: {:?}",
987-
e
988-
);
989-
ErrorMessage::internal_server_error(&format!(
990-
"Failed to parse chat completion response: {}",
991-
e
992-
))
993-
})?;
978+
let response =
979+
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
980+
.await
981+
.map_err(|e| {
982+
tracing::error!(
983+
request_id,
984+
"Failed to parse chat completion response: {:?}",
985+
e
986+
);
987+
ErrorMessage::internal_server_error(&format!(
988+
"Failed to parse chat completion response: {}",
989+
e
990+
))
991+
})?;
994992

995993
inflight_guard.mark_ok();
996994
Ok(Json(response).into_response())
@@ -2210,7 +2208,8 @@ mod tests {
22102208
use futures::stream;
22112209

22122210
// Create an error event with JSON payload containing error code in comment
2213-
let error_json = r#"{"message":"prompt > max_seq_len","type":"Internal Server Error","code":500}"#;
2211+
let error_json =
2212+
r#"{"message":"prompt > max_seq_len","type":"Internal Server Error","code":500}"#;
22142213
let error_event = Annotated::<NvCreateChatCompletionStreamResponse> {
22152214
data: None,
22162215
id: None,

0 commit comments

Comments
 (0)