From bb77f61aafd523fd713fd0ea09ff316fb6b2cc51 Mon Sep 17 00:00:00 2001 From: declark1 Date: Tue, 1 Oct 2024 13:37:38 -0700 Subject: [PATCH 01/50] Add initial client refactor code (wip) Signed-off-by: declark1 --- Cargo.lock | 2 + Cargo.toml | 1 + config/config.yaml | 10 +- config/test.config.yaml | 10 +- src/clients.rs | 560 +++++++++------------- src/clients/chunker.rs | 94 ++-- src/clients/detector.rs | 301 +----------- src/clients/detector/text_contents.rs | 145 ++++++ src/clients/detector/text_context_chat.rs | 40 ++ src/clients/detector/text_context_doc.rs | 110 +++++ src/clients/detector/text_generation.rs | 88 ++++ src/clients/errors.rs | 91 ++++ src/clients/generation.rs | 43 +- src/clients/http.rs | 182 +++++++ src/clients/nlp.rs | 107 ++--- src/clients/openai.rs | 423 ++++++++++++++++ src/clients/tgis.rs | 123 ++--- src/config.rs | 124 ++++- src/health.rs | 276 ++++------- src/lib.rs | 2 +- src/models.rs | 13 + src/orchestrator.rs | 208 +++++--- src/orchestrator/streaming.rs | 28 +- src/orchestrator/unary.rs | 165 ++++--- src/server.rs | 19 +- tests/test.config.yaml | 3 +- 26 files changed, 1984 insertions(+), 1184 deletions(-) create mode 100644 src/clients/detector/text_contents.rs create mode 100644 src/clients/detector/text_context_chat.rs create mode 100644 src/clients/detector/text_context_doc.rs create mode 100644 src/clients/detector/text_generation.rs create mode 100644 src/clients/errors.rs create mode 100644 src/clients/http.rs create mode 100644 src/clients/openai.rs diff --git a/Cargo.lock b/Cargo.lock index 72ede252..f33e376d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -606,6 +606,7 @@ dependencies = [ "ginepro", "hyper", "hyper-util", + "indexmap 2.5.0", "mio", "prost", "reqwest", @@ -1045,6 +1046,7 @@ checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", "hashbrown 0.14.5", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b4ae3e1e..c1079f6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ url = "2.5.2" uuid = { version = "1.10.0", features = ["v4", "fast-rng"] } async-trait = "0.1.81" async-stream = "0.3.5" +indexmap = { version = "2.5.0", features = ["serde"] } [build-dependencies] tonic-build = "0.12.1" diff --git a/config/config.yaml b/config/config.yaml index 2fe6a657..4f13ea18 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -11,6 +11,12 @@ generation: service: hostname: localhost port: 8033 +# Generation server used for chat endpoints +# chat_generation: +# provider: openai +# service: +# hostname: http://localhost +# port: 8080 # Any chunker servers that will be used by any detectors chunkers: # Chunker ID/name @@ -26,8 +32,10 @@ chunkers: detectors: # Detector ID/name to be used in user requests hap-en: + # Detector type (text_contents, text_generation, text_context_chat, text_context_doc) + type: text_contents service: - hostname: https://localhost/api/v1/text/contents # Full url / endpoint currently expected + hostname: https://localhost port: 8080 # TLS ID/name, optional (detailed in `tls` section) tls: detector diff --git a/config/test.config.yaml b/config/test.config.yaml index f32a0a26..0decc09a 100644 --- a/config/test.config.yaml +++ b/config/test.config.yaml @@ -3,6 +3,11 @@ generation: service: hostname: localhost port: 443 +# chat_generation: +# provider: openai +# service: +# hostname: http://localhost +# port: 8080 chunkers: test_chunker: type: sentence @@ -11,8 +16,9 @@ chunkers: port: 8085 detectors: test_detector: + type: text_contents service: - hostname: https://localhost/api/v1/text/contents + hostname: https://localhost port: 8000 chunker_id: test_chunker - default_threshold: 0.5 \ No newline at end of file + default_threshold: 0.5 diff --git a/src/clients.rs b/src/clients.rs index 4e40ae34..dc87d5e6 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -16,29 +16,35 @@ */ #![allow(dead_code)] -// Import error for adding `source` trait -use std::{collections::HashMap, error::Error as _, fmt::Display, pin::Pin, time::Duration}; +use std::{ + any::TypeId, + collections::{hash_map, HashMap}, + pin::Pin, + time::Duration, +}; -use futures::{future::join_all, Stream}; +use async_trait::async_trait; +use futures::Stream; use ginepro::LoadBalancedChannel; -use reqwest::{Response, StatusCode}; use tokio::{fs::File, io::AsyncReadExt}; -use tracing::error; use url::Url; use crate::{ config::{ServiceConfig, Tls}, - health::{HealthCheck, HealthCheckResult, HealthStatus, OptionalHealthCheckResponseBody}, + health::HealthCheckResult, }; +pub mod errors; +pub use errors::{ClientCode, Error}; + +pub mod http; +pub use http::HttpClient; + pub mod chunker; pub use chunker::ChunkerClient; pub mod detector; -pub use detector::DetectorClient; - -pub mod generation; -pub use generation::GenerationClient; +pub use detector::TextContentsDetectorClient; pub mod tgis; pub use tgis::TgisClient; @@ -46,358 +52,266 @@ pub use tgis::TgisClient; pub mod nlp; pub use nlp::NlpClient; -pub const DEFAULT_TGIS_PORT: u16 = 8033; -pub const DEFAULT_CAIKIT_NLP_PORT: u16 = 8085; -pub const DEFAULT_CHUNKER_PORT: u16 = 8085; -pub const DEFAULT_DETECTOR_PORT: u16 = 8080; -pub const COMMON_ROUTER_KEY: &str = "common-router"; +pub mod generation; +pub use generation::GenerationClient; + +pub mod openai; + const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60); const DEFAULT_REQUEST_TIMEOUT_SEC: u64 = 600; pub type BoxStream = Pin + Send>>; -/// Client errors. -#[derive(Debug, Clone, PartialEq, thiserror::Error)] -pub enum Error { - #[error("{}", .message)] - Grpc { code: StatusCode, message: String }, - #[error("{}", .message)] - Http { code: StatusCode, message: String }, - #[error("model not found: {model_id}")] - ModelNotFound { model_id: String }, +mod private { + pub struct Seal; } -impl Error { - /// Returns status code. - pub fn status_code(&self) -> StatusCode { - match self { - // Return equivalent http status code for grpc status code - Error::Grpc { code, .. } => *code, - // Return http status code for error responses - // and 500 for other errors - Error::Http { code, .. } => *code, - // Return 404 for model not found - Error::ModelNotFound { .. } => StatusCode::NOT_FOUND, - } +#[async_trait] +pub trait Client: Send + Sync + 'static { + /// Returns the name of the client type. + fn name(&self) -> &str; + + /// Returns the `TypeId` of the client type. Sealed to prevent overrides. + fn type_id(&self, _: private::Seal) -> TypeId { + TypeId::of::() } + + /// Performs a client health check. + async fn health(&self) -> HealthCheckResult; } -impl From for Error { - fn from(value: reqwest::Error) -> Self { - // Log lower level source of error. - // Examples: - // 1. client error (Connect) // Cases like connection error, wrong port etc. - // 2. client error (SendRequest) // Cases like cert issues - error!( - "http request failed. Source: {}", - value.source().unwrap().to_string() - ); - // Return http status code for error responses - // and 500 for other errors - let code = match value.status() { - Some(code) => code, - None => StatusCode::INTERNAL_SERVER_ERROR, - }; - Self::Http { - code, - message: value.to_string(), - } +impl dyn Client { + pub fn is(&self) -> bool { + TypeId::of::() == self.type_id(private::Seal) } -} -impl From for Error { - fn from(value: tonic::Status) -> Self { - use tonic::Code::*; - // Return equivalent http status code for grpc status code - let code = match value.code() { - InvalidArgument => StatusCode::BAD_REQUEST, - Internal => StatusCode::INTERNAL_SERVER_ERROR, - NotFound => StatusCode::NOT_FOUND, - DeadlineExceeded => StatusCode::REQUEST_TIMEOUT, - Unimplemented => StatusCode::NOT_IMPLEMENTED, - Unauthenticated => StatusCode::UNAUTHORIZED, - PermissionDenied => StatusCode::FORBIDDEN, - Unavailable => StatusCode::SERVICE_UNAVAILABLE, - Ok => StatusCode::OK, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; - Self::Grpc { - code, - message: value.message().to_string(), + pub fn downcast(self: Box) -> Result, Box> { + if (*self).is::() { + let ptr = Box::into_raw(self) as *mut T; + // SAFETY: guaranteed by `is` + unsafe { Ok(Box::from_raw(ptr)) } + } else { + Err(self) } } -} -#[derive(Debug, Clone, PartialEq)] -pub enum ClientCode { - Http(StatusCode), - Grpc(tonic::Code), -} + pub fn downcast_ref(&self) -> Option<&T> { + if (*self).is::() { + let ptr = self as *const dyn Client as *const T; + // SAFETY: guaranteed by `is` + unsafe { Some(&*ptr) } + } else { + None + } + } -impl Display for ClientCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ClientCode::Http(code) => write!(f, "HTTP {}", code), - ClientCode::Grpc(code) => write!(f, "gRPC {:?} {}", code, code), + pub fn downcast_mut(&mut self) -> Option<&mut T> { + if (*self).is::() { + let ptr = self as *mut dyn Client as *mut T; + // SAFETY: guaranteed by `is` + unsafe { Some(&mut *ptr) } + } else { + None } } } -#[derive(Clone)] -pub struct HttpClient { - base_url: Url, - health_url: Url, - client: reqwest::Client, -} +/// A map containing different types of clients. +#[derive(Default)] +pub struct ClientMap(HashMap>); -impl HttpClient { - pub fn new(base_url: Url, client: reqwest::Client) -> Self { - let health_url = extract_base_url(&base_url).join("health").unwrap(); - Self { - base_url, - health_url, - client, - } +impl ClientMap { + /// Creates an empty `ClientMap`. + #[inline] + pub fn new() -> Self { + Self(HashMap::new()) } - pub fn base_url(&self) -> &Url { - &self.base_url + /// Inserts a client into the map. + #[inline] + pub fn insert(&mut self, key: String, value: V) { + self.0.insert(key, Box::new(value)); } - /// This is sectioned off to allow for testing. - pub(super) async fn http_response_to_health_check_result( - res: Result, - ) -> HealthCheckResult { - match res { - Ok(response) => { - if response.status() == StatusCode::OK { - if let Ok(body) = response.json::().await { - // If the service provided a body, we only anticipate a minimal health status and optional reason. - HealthCheckResult { - health_status: body.health_status.clone(), - response_code: ClientCode::Http(StatusCode::OK), - reason: match body.health_status { - HealthStatus::Healthy => None, - _ => body.reason, - }, - } - } else { - // If the service did not provide a body, we assume it is healthy. - HealthCheckResult { - health_status: HealthStatus::Healthy, - response_code: ClientCode::Http(StatusCode::OK), - reason: None, - } - } - } else { - HealthCheckResult { - // The most we can presume is that 5xx errors are likely indicating service issues, implying the service is unhealthy. - // and that 4xx errors are more likely indicating health check failures, i.e. due to configuration/implementation issues. - // Regardless we can't be certain, so the reason is also provided. - // TODO: We will likely circle back to re-evaluate this logic in the future - // when we know more about how the client health results will be used. - health_status: if response.status().as_u16() >= 500 - && response.status().as_u16() < 600 - { - HealthStatus::Unhealthy - } else if response.status().as_u16() >= 400 - && response.status().as_u16() < 500 - { - HealthStatus::Unknown - } else { - error!( - "unexpected http health check status code: {}", - response.status() - ); - HealthStatus::Unknown - }, - response_code: ClientCode::Http(response.status()), - reason: Some(format!( - "{}{}", - response.error_for_status_ref().unwrap_err(), - response - .text() - .await - .map(|s| if s.is_empty() { - "".to_string() - } else { - format!(": {}", s) - }) - .unwrap_or("".to_string()) - )), - } - } - } - Err(e) => { - error!("error checking health: {}", e); - HealthCheckResult { - health_status: HealthStatus::Unknown, - response_code: ClientCode::Http( - e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), - ), - reason: Some(e.to_string()), - } - } - } + /// Returns a reference to the client trait object. + #[inline] + pub fn get(&self, key: &str) -> Option<&dyn Client> { + self.0.get(key).map(|v| v.as_ref()) } -} -impl HealthCheck for HttpClient { - async fn check(&self) -> HealthCheckResult { - let res = self.get(self.health_url.clone()).send().await; - Self::http_response_to_health_check_result(res).await + /// Returns a mutable reference to the client trait object. + #[inline] + pub fn get_mut(&mut self, key: &str) -> Option<&mut dyn Client> { + self.0.get_mut(key).map(|v| v.as_mut()) + } + + /// Downcasts and returns a reference to the concrete client type. + #[inline] + pub fn get_as(&self, key: &str) -> Option<&V> { + self.0.get(key)?.downcast_ref::() + } + + /// Downcasts and returns a mutable reference to the concrete client type. + #[inline] + pub fn get_mut_as(&mut self, key: &str) -> Option<&mut V> { + self.0.get_mut(key)?.downcast_mut::() + } + + /// Removes a client from the map. + #[inline] + pub fn remove(&mut self, key: &str) -> Option> { + self.0.remove(key) + } + + /// An iterator visiting all key-value pairs in arbitrary order. + #[inline] + pub fn iter(&self) -> hash_map::Iter<'_, String, Box> { + self.0.iter() + } + + /// An iterator visiting all keys in arbitrary order. + #[inline] + pub fn keys(&self) -> hash_map::Keys<'_, String, Box> { + self.0.keys() + } + + /// An iterator visiting all values in arbitrary order. + #[inline] + pub fn values(&self) -> hash_map::Values<'_, String, Box> { + self.0.values() } -} -impl std::ops::Deref for HttpClient { - type Target = reqwest::Client; + /// Returns the number of elements in the map. + #[inline] + pub fn len(&self) -> usize { + self.0.len() + } - fn deref(&self) -> &Self::Target { - &self.client + /// Returns `true` if the map contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.0.is_empty() } } -pub async fn create_http_clients( - default_port: u16, - config: &[(String, ServiceConfig)], -) -> HashMap { - let clients = config - .iter() - .map(|(name, service_config)| async move { - let port = service_config.port.unwrap_or(default_port); - let mut base_url = Url::parse(&service_config.hostname).unwrap(); - base_url.set_port(Some(port)).unwrap(); - let request_timeout = Duration::from_secs( - service_config - .request_timeout - .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC), - ); - let mut builder = reqwest::ClientBuilder::new() - .connect_timeout(DEFAULT_CONNECT_TIMEOUT) - .timeout(request_timeout); - if let Some(Tls::Config(tls_config)) = &service_config.tls { - let mut cert_buf = Vec::new(); - let cert_path = tls_config.cert_path.as_ref().unwrap().as_path(); - File::open(cert_path) - .await - .unwrap_or_else(|error| { - panic!("error reading cert from {cert_path:?}: {error}") - }) - .read_to_end(&mut cert_buf) - .await - .unwrap(); - - if let Some(key_path) = &tls_config.key_path { - File::open(key_path) - .await - .unwrap_or_else(|error| { - panic!("error reading key from {key_path:?}: {error}") - }) - .read_to_end(&mut cert_buf) - .await - .unwrap(); - } - let identity = reqwest::Identity::from_pem(&cert_buf).unwrap_or_else(|error| { - panic!("error parsing bundled client certificate: {error}") - }); +pub async fn create_http_client(default_port: u16, service_config: &ServiceConfig) -> HttpClient { + let port = service_config.port.unwrap_or(default_port); + let mut base_url = Url::parse(&service_config.hostname).unwrap(); + base_url.set_port(Some(port)).unwrap(); + let request_timeout = Duration::from_secs( + service_config + .request_timeout + .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC), + ); + let mut builder = reqwest::ClientBuilder::new() + .connect_timeout(DEFAULT_CONNECT_TIMEOUT) + .timeout(request_timeout); + if let Some(Tls::Config(tls_config)) = &service_config.tls { + let mut cert_buf = Vec::new(); + let cert_path = tls_config.cert_path.as_ref().unwrap().as_path(); + File::open(cert_path) + .await + .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}")) + .read_to_end(&mut cert_buf) + .await + .unwrap(); + + if let Some(key_path) = &tls_config.key_path { + File::open(key_path) + .await + .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}")) + .read_to_end(&mut cert_buf) + .await + .unwrap(); + } + let identity = reqwest::Identity::from_pem(&cert_buf) + .unwrap_or_else(|error| panic!("error parsing bundled client certificate: {error}")); - builder = builder.use_rustls_tls().identity(identity); - builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false)); - - if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path { - let ca_cert = - tokio::fs::read(client_ca_cert_path) - .await - .unwrap_or_else(|error| { - panic!("error reading cert from {client_ca_cert_path:?}: {error}") - }); - let cacert = reqwest::Certificate::from_pem(&ca_cert) - .unwrap_or_else(|error| panic!("error parsing ca cert: {error}")); - builder = builder.add_root_certificate(cacert) - } - } - let client = builder - .build() - .unwrap_or_else(|error| panic!("error creating http client for {name}: {error}")); - let client = HttpClient::new(base_url, client); - (name.clone(), client) - }) - .collect::>(); - join_all(clients).await.into_iter().collect() + builder = builder.use_rustls_tls().identity(identity); + builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false)); + + if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path { + let ca_cert = tokio::fs::read(client_ca_cert_path) + .await + .unwrap_or_else(|error| { + panic!("error reading cert from {client_ca_cert_path:?}: {error}") + }); + let cacert = reqwest::Certificate::from_pem(&ca_cert) + .unwrap_or_else(|error| panic!("error parsing ca cert: {error}")); + builder = builder.add_root_certificate(cacert) + } + } + let client = builder + .build() + .unwrap_or_else(|error| panic!("error creating http client: {error}")); + HttpClient::new(base_url, client) } -async fn create_grpc_clients( +pub async fn create_grpc_client( default_port: u16, - config: &[(String, ServiceConfig)], + service_config: &ServiceConfig, new: fn(LoadBalancedChannel) -> C, -) -> HashMap { - let clients = config - .iter() - .map(|(name, service_config)| async move { - let request_timeout = Duration::from_secs(service_config.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC)); - let mut builder = LoadBalancedChannel::builder(( - service_config.hostname.clone(), - service_config.port.unwrap_or(default_port), - )) - .connect_timeout(DEFAULT_CONNECT_TIMEOUT) - .timeout(request_timeout); - - let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls { - let cert_path = tls_config.cert_path.as_ref().unwrap().as_path(); - let key_path = tls_config.key_path.as_ref().unwrap().as_path(); - let cert_pem = tokio::fs::read(cert_path) +) -> C { + let request_timeout = Duration::from_secs( + service_config + .request_timeout + .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC), + ); + let mut builder = LoadBalancedChannel::builder(( + service_config.hostname.clone(), + service_config.port.unwrap_or(default_port), + )) + .connect_timeout(DEFAULT_CONNECT_TIMEOUT) + .timeout(request_timeout); + + let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls { + let cert_path = tls_config.cert_path.as_ref().unwrap().as_path(); + let key_path = tls_config.key_path.as_ref().unwrap().as_path(); + let cert_pem = tokio::fs::read(cert_path) + .await + .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}")); + let key_pem = tokio::fs::read(key_path) + .await + .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}")); + let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem); + let mut client_tls_config = tonic::transport::ClientTlsConfig::new() + .identity(identity) + .with_native_roots() + .with_webpki_roots(); + if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path { + let client_ca_cert_pem = + tokio::fs::read(client_ca_cert_path) .await - .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}")); - let key_pem = tokio::fs::read(key_path) - .await - .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}")); - let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem); - let mut client_tls_config = - tonic::transport::ClientTlsConfig::new().identity(identity).with_native_roots().with_webpki_roots(); - if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path { - let client_ca_cert_pem = tokio::fs::read(client_ca_cert_path) - .await - .unwrap_or_else(|error| { - panic!("error reading client ca cert from {client_ca_cert_path:?}: {error}") - }); - client_tls_config = client_tls_config.ca_certificate( - tonic::transport::Certificate::from_pem(client_ca_cert_pem), - ); - } - Some(client_tls_config) - } else { - None - }; - if let Some(client_tls_config) = client_tls_config { - builder = builder.with_tls(client_tls_config); - } - let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}")); - (name.clone(), new(channel)) - }) - .collect::>(); - join_all(clients).await.into_iter().collect() -} - -/// Extracts a base url from a url including path segments. -fn extract_base_url(url: &Url) -> Url { - let mut url = url.clone(); - match url.path_segments_mut() { - Ok(mut path) => { - path.clear(); - } - Err(_) => { - panic!("url cannot be a base"); + .unwrap_or_else(|error| { + panic!("error reading client ca cert from {client_ca_cert_path:?}: {error}") + }); + client_tls_config = client_tls_config + .ca_certificate(tonic::transport::Certificate::from_pem(client_ca_cert_pem)); } + Some(client_tls_config) + } else { + None + }; + if let Some(client_tls_config) = client_tls_config { + builder = builder.with_tls(client_tls_config); } - url + let channel = builder + .channel() + .await + .unwrap_or_else(|error| panic!("error creating grpc client: {error}")); + new(channel) } #[cfg(test)] mod tests { - use hyper::http; + use hyper::{http, StatusCode}; + use reqwest::Response; use super::*; - use crate::pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse}; + use crate::{ + health::{HealthCheckResult, HealthStatus}, + pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse}, + }; async fn mock_http_response( status: StatusCode, @@ -662,20 +576,4 @@ mod tests { ); } } - - #[test] - fn test_extract_base_url() { - let url = - Url::parse("https://example-detector.route.example.com/api/v1/text/contents").unwrap(); - let base_url = extract_base_url(&url); - assert_eq!( - Url::parse("https://example-detector.route.example.com/").unwrap(), - base_url - ); - let health_url = base_url.join("/health").unwrap(); - assert_eq!( - Url::parse("https://example-detector.route.example.com/health").unwrap(), - health_url - ); - } } diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 35b04591..e1e097a3 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -15,19 +15,19 @@ */ -use std::{collections::HashMap, pin::Pin}; +use std::pin::Pin; +use async_trait::async_trait; use futures::{Future, Stream, StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status, Streaming}; +use tonic::{Code, Request, Response, Status, Streaming}; use tracing::info; -use super::{create_grpc_clients, BoxStream, Error}; +use super::{BoxStream, Client, ClientCode, Error}; use crate::{ - config::ServiceConfig, - health::{HealthCheckResult, HealthProbe}, + health::{HealthCheckResult, HealthStatus}, pb::{ caikit::runtime::chunkers::{ chunkers_service_client::ChunkersServiceClient, @@ -45,53 +45,25 @@ pub const DEFAULT_MODEL_ID: &str = "whole_doc_chunker"; type StreamingTokenizationResult = Result>, Status>; -#[cfg_attr(test, faux::create, derive(Default))] +#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct ChunkerClient { - clients: HashMap>, - health_clients: HashMap>, -} - -#[cfg_attr(test, faux::methods)] -impl HealthProbe for ChunkerClient { - async fn health(&self) -> Result, Error> { - let mut results = HashMap::with_capacity(self.health_clients.len()); - for (model_id, mut client) in self.health_clients.clone() { - results.insert( - model_id.clone(), - client - .check(HealthCheckRequest { - service: "".to_string(), - }) // Caikit does not expect a service_id to be specified - .await - .into(), - ); - } - Ok(results) - } + client: ChunkersServiceClient, + health_client: HealthClient, } #[cfg_attr(test, faux::methods)] impl ChunkerClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { - let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await; - let health_clients = create_grpc_clients(default_port, config, HealthClient::new).await; + pub fn new( + client: ChunkersServiceClient, + health_client: HealthClient, + ) -> Self { Self { - clients, - health_clients, + client, + health_client, } } - fn client(&self, model_id: &str) -> Result, Error> { - Ok(self - .clients - .get(model_id) - .ok_or_else(|| Error::ModelNotFound { - model_id: model_id.to_string(), - })? - .clone()) - } - pub async fn tokenization_task_predict( &self, model_id: &str, @@ -102,9 +74,9 @@ impl ChunkerClient { info!("Using default whole doc chunker"); return Ok(tokenize_whole_doc(request)); } + let mut client = self.client.clone(); let request = request_with_model_id(request, model_id); - Ok(self - .client(model_id)? + Ok(client .chunker_tokenization_task_predict(request) .await? .into_inner()) @@ -126,7 +98,7 @@ impl ChunkerClient { }); ReceiverStream::new(response_rx).boxed() } else { - let mut client = self.client(model_id)?; + let mut client = self.client.clone(); let request = request_with_model_id(request_stream, model_id); // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors. // https://github.com/rust-lang/rust/issues/110338 @@ -143,6 +115,38 @@ impl ChunkerClient { } } +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for ChunkerClient { + fn name(&self) -> &str { + "chunker" + } + + async fn health(&self) -> HealthCheckResult { + let mut client = self.health_client.clone(); + let response = client + .check(HealthCheckRequest { service: "".into() }) + .await; + let code = match response { + Ok(_) => Code::Ok, + Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => { + Code::Ok + } + Err(status) => status.code(), + }; + let health_status = if matches!(code, Code::Ok) { + HealthStatus::Healthy + } else { + HealthStatus::Unhealthy + }; + HealthCheckResult { + health_status, + response_code: ClientCode::Grpc(code), + reason: None, + } + } +} + fn request_with_model_id(request: T, model_id: &str) -> Request { let mut request = Request::new(request); request diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 455612bc..c00c1dfc 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -15,234 +15,21 @@ */ -use std::collections::HashMap; - -use hyper::{HeaderMap, StatusCode}; -use serde::{Deserialize, Serialize}; - -use super::{create_http_clients, Error, HttpClient}; -use crate::{ - config::ServiceConfig, - health::{HealthCheck, HealthCheckResult, HealthProbe}, - models::{DetectionResult, DetectorParams}, -}; +pub mod text_contents; +pub use text_contents::*; +pub mod text_context_chat; +pub use text_context_chat::*; +pub mod text_context_doc; +pub use text_context_doc::*; +pub mod text_generation; +use hyper::StatusCode; +use serde::Deserialize; +pub use text_generation::*; + +use super::Error; const DETECTOR_ID_HEADER_NAME: &str = "detector-id"; -// For some reason the order matters here. #[cfg_attr(test, derive(Default), faux::create)] doesn't work. (rustc --explain E0560) -#[cfg_attr(test, faux::create, derive(Default))] -#[derive(Clone)] -pub struct DetectorClient { - clients: HashMap, -} - -#[cfg_attr(test, faux::methods)] -impl HealthProbe for DetectorClient { - async fn health(&self) -> Result, Error> { - let mut results = HashMap::with_capacity(self.clients.len()); - for (model_id, client) in self.clients() { - results.insert(model_id.to_string(), client.check().await); - } - Ok(results) - } -} - -#[cfg_attr(test, faux::methods)] -impl DetectorClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { - let clients: HashMap = create_http_clients(default_port, config).await; - Self { clients } - } - - fn client(&self, model_id: &str) -> Result { - Ok(self - .clients - .get(model_id) - .ok_or_else(|| Error::ModelNotFound { - model_id: model_id.to_string(), - })? - .clone()) - } - - fn clients(&self) -> impl Iterator { - self.clients.iter() - } - - // TODO: Use generics here, since the only thing that changes in comparison to generation_detection() - // is the "request" parameter and return types? - /// Invokes detectors implemented with the `/api/v1/text/contents` endpoint - pub async fn text_contents( - &self, - model_id: &str, - request: ContentAnalysisRequest, - headers: HeaderMap, - ) -> Result>, Error> { - let client = self.client(model_id)?; - let url = client.base_url().as_str(); - let response = client - .post(url) - .headers(headers) - .header(DETECTOR_ID_HEADER_NAME, model_id) - .json(&request) - .send() - .await?; - if response.status() == StatusCode::OK { - Ok(response.json().await?) - } else { - let code = response.status().as_u16(); - let error = response - .json::() - .await - .unwrap_or(DetectorError { - code, - message: "".into(), - }); - Err(error.into()) - } - } - - /// Invokes detectors implemented with the `/api/v1/text/generation` endpoint - pub async fn text_generation( - &self, - model_id: &str, - request: GenerationDetectionRequest, - headers: HeaderMap, - ) -> Result, Error> { - let client = self.client(model_id)?; - let url = client.base_url().as_str(); - let response = client - .post(url) - .headers(headers) - .header(DETECTOR_ID_HEADER_NAME, model_id) - .json(&request) - .send() - .await?; - if response.status() == StatusCode::OK { - Ok(response.json().await?) - } else { - let code = response.status().as_u16(); - let error = response - .json::() - .await - .unwrap_or(DetectorError { - code, - message: "".into(), - }); - Err(error.into()) - } - } - - /// Invokes detectors implemented with the `/api/v1/text/context/doc` endpoint - pub async fn text_context_doc( - &self, - model_id: &str, - request: ContextDocsDetectionRequest, - headers: HeaderMap, - ) -> Result, Error> { - let client = self.client(model_id)?; - let url = client.base_url().as_str(); - let response = client - .post(url) - .headers(headers) - .header(DETECTOR_ID_HEADER_NAME, model_id) - .json(&request) - .send() - .await?; - if response.status() == StatusCode::OK { - Ok(response.json().await?) - } else { - let code = response.status().as_u16(); - let error = response - .json::() - .await - .unwrap_or(DetectorError { - code, - message: "".into(), - }); - Err(error.into()) - } - } -} - -/// Request for text content analysis -/// Results of this request will contain analysis / detection of each of the provided documents -/// in the order they are present in the `contents` object. -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct ContentAnalysisRequest { - /// Field allowing users to provide list of documents for analysis - pub contents: Vec, -} - -impl ContentAnalysisRequest { - pub fn new(contents: Vec) -> ContentAnalysisRequest { - ContentAnalysisRequest { contents } - } -} - -/// Evidence -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Evidence { - /// Evidence name - pub name: String, - /// Optional, evidence value - #[serde(skip_serializing_if = "Option::is_none")] - pub value: Option, - /// Optional, score for evidence - #[serde(skip_serializing_if = "Option::is_none")] - pub score: Option, -} - -/// Evidence in response -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] -pub struct EvidenceObj { - /// Evidence name - pub name: String, - /// Optional, evidence value - #[serde(skip_serializing_if = "Option::is_none")] - pub value: Option, - /// Optional, score for evidence - #[serde(skip_serializing_if = "Option::is_none")] - pub score: Option, - /// Optional, evidence on evidence value - // Evidence nesting should likely not go beyond this - #[serde(skip_serializing_if = "Option::is_none")] - pub evidence: Option>, -} - -/// Response of text content analysis endpoint -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct ContentAnalysisResponse { - /// Start index of detection - pub start: usize, - /// End index of detection - pub end: usize, - /// Text corresponding to detection - pub text: String, - /// Relevant detection class - pub detection: String, - /// Detection type or aggregate detection label - pub detection_type: String, - /// Score of detection - pub score: f64, - /// Optional, any applicable evidence for detection - #[serde(skip_serializing_if = "Option::is_none")] - pub evidence: Option>, -} - -impl From for crate::models::TokenClassificationResult { - fn from(value: ContentAnalysisResponse) -> Self { - Self { - start: value.start as u32, - end: value.end as u32, - word: value.text, - entity: value.detection, - entity_group: value.detection_type, - score: value.score, - token_count: None, - } - } -} - #[derive(Debug, Clone, Deserialize)] pub struct DetectorError { pub code: u16, @@ -257,67 +44,3 @@ impl From for Error { } } } - -/// A struct representing a request to a detector compatible with the -/// /api/v1/text/generation endpoint. -#[cfg_attr(test, derive(PartialEq))] -#[derive(Debug, Serialize)] -pub struct GenerationDetectionRequest { - /// User prompt sent to LLM - pub prompt: String, - - /// Text generated from an LLM - pub generated_text: String, -} - -impl GenerationDetectionRequest { - pub fn new(prompt: String, generated_text: String) -> Self { - Self { - prompt, - generated_text, - } - } -} - -/// A struct representing a request to a detector compatible with the -/// /api/v1/text/context/doc endpoint. -#[cfg_attr(test, derive(PartialEq))] -#[derive(Debug, Serialize)] -pub struct ContextDocsDetectionRequest { - /// Content to run detection on - pub content: String, - - /// Type of context being sent - pub context_type: ContextType, - - /// Context to run detection on - pub context: Vec, - - // Detector Params - pub detector_params: DetectorParams, -} - -/// Enum representing the context type of a detection -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum ContextType { - #[serde(rename = "docs")] - Document, - #[serde(rename = "url")] - Url, -} - -impl ContextDocsDetectionRequest { - pub fn new( - content: String, - context_type: ContextType, - context: Vec, - detector_params: DetectorParams, - ) -> Self { - Self { - content, - context_type, - context, - detector_params, - } - } -} diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs new file mode 100644 index 00000000..13b85c25 --- /dev/null +++ b/src/clients/detector/text_contents.rs @@ -0,0 +1,145 @@ +use async_trait::async_trait; +use hyper::StatusCode; +use serde::{Deserialize, Serialize}; + +use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; +use crate::{ + clients::{Client, Error, HttpClient}, + health::HealthCheckResult, +}; + +#[cfg_attr(test, faux::create)] +#[derive(Clone)] +pub struct TextContentsDetectorClient { + client: HttpClient, +} + +#[cfg_attr(test, faux::methods)] +impl TextContentsDetectorClient { + pub fn new(client: HttpClient) -> Self { + Self { client } + } + + pub async fn text_contents( + &self, + model_id: &str, + request: ContentAnalysisRequest, + ) -> Result>, Error> { + let url = self + .client + .base_url() + .join("/api/v1/text/contents") + .unwrap(); + let response = self + .client + .post(url) + .header(DETECTOR_ID_HEADER_NAME, model_id) + .json(&request) + .send() + .await?; + if response.status() == StatusCode::OK { + Ok(response.json().await?) + } else { + let code = response.status().as_u16(); + let error = response + .json::() + .await + .unwrap_or(DetectorError { + code, + message: "".into(), + }); + Err(error.into()) + } + } +} + +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for TextContentsDetectorClient { + fn name(&self) -> &str { + "text_contents_detector" + } + + async fn health(&self) -> HealthCheckResult { + self.client.health().await + } +} + +/// Request for text content analysis +/// Results of this request will contain analysis / detection of each of the provided documents +/// in the order they are present in the `contents` object. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ContentAnalysisRequest { + /// Field allowing users to provide list of documents for analysis + pub contents: Vec, +} + +impl ContentAnalysisRequest { + pub fn new(contents: Vec) -> ContentAnalysisRequest { + ContentAnalysisRequest { contents } + } +} + +/// Response of text content analysis endpoint +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ContentAnalysisResponse { + /// Start index of detection + pub start: usize, + /// End index of detection + pub end: usize, + /// Text corresponding to detection + pub text: String, + /// Relevant detection class + pub detection: String, + /// Detection type or aggregate detection label + pub detection_type: String, + /// Score of detection + pub score: f64, + /// Optional, any applicable evidence for detection + #[serde(skip_serializing_if = "Option::is_none")] + pub evidence: Option>, +} + +impl From for crate::models::TokenClassificationResult { + fn from(value: ContentAnalysisResponse) -> Self { + Self { + start: value.start as u32, + end: value.end as u32, + word: value.text, + entity: value.detection, + entity_group: value.detection_type, + score: value.score, + token_count: None, + } + } +} + +/// Evidence +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct Evidence { + /// Evidence name + pub name: String, + /// Optional, evidence value + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, + /// Optional, score for evidence + #[serde(skip_serializing_if = "Option::is_none")] + pub score: Option, +} + +/// Evidence in response +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct EvidenceObj { + /// Evidence name + pub name: String, + /// Optional, evidence value + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, + /// Optional, score for evidence + #[serde(skip_serializing_if = "Option::is_none")] + pub score: Option, + /// Optional, evidence on evidence value + // Evidence nesting should likely not go beyond this + #[serde(skip_serializing_if = "Option::is_none")] + pub evidence: Option>, +} diff --git a/src/clients/detector/text_context_chat.rs b/src/clients/detector/text_context_chat.rs new file mode 100644 index 00000000..81e06118 --- /dev/null +++ b/src/clients/detector/text_context_chat.rs @@ -0,0 +1,40 @@ +use async_trait::async_trait; + +use crate::{ + clients::{Client, HttpClient}, + health::HealthCheckResult, +}; + +#[cfg_attr(test, faux::create)] +#[derive(Clone)] +pub struct TextContextChatDetectorClient { + client: HttpClient, +} + +#[cfg_attr(test, faux::methods)] +impl TextContextChatDetectorClient { + pub fn new(client: HttpClient) -> Self { + Self { client } + } + + pub async fn text_context_chat(&self) { + let _url = self + .client + .base_url() + .join("/api/v1/text/context/chat") + .unwrap(); + todo!() + } +} + +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for TextContextChatDetectorClient { + fn name(&self) -> &str { + "text_context_chat_detector" + } + + async fn health(&self) -> HealthCheckResult { + self.client.health().await + } +} diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs new file mode 100644 index 00000000..3858c4cc --- /dev/null +++ b/src/clients/detector/text_context_doc.rs @@ -0,0 +1,110 @@ +use async_trait::async_trait; +use hyper::StatusCode; +use serde::{Deserialize, Serialize}; + +use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; +use crate::{ + clients::{Client, Error, HttpClient}, + health::HealthCheckResult, + models::{DetectionResult, DetectorParams}, +}; + +#[cfg_attr(test, faux::create)] +#[derive(Clone)] +pub struct TextContextDocDetectorClient { + client: HttpClient, +} + +#[cfg_attr(test, faux::methods)] +impl TextContextDocDetectorClient { + pub fn new(client: HttpClient) -> Self { + Self { client } + } + + pub async fn text_context_docs( + &self, + model_id: &str, + request: ContextDocsDetectionRequest, + ) -> Result, Error> { + let url = self + .client + .base_url() + .join("/api/v1/text/context/doc") + .unwrap(); + let response = self + .client + .post(url) + .header(DETECTOR_ID_HEADER_NAME, model_id) + .json(&request) + .send() + .await?; + if response.status() == StatusCode::OK { + Ok(response.json().await?) + } else { + let code = response.status().as_u16(); + let error = response + .json::() + .await + .unwrap_or(DetectorError { + code, + message: "".into(), + }); + Err(error.into()) + } + } +} + +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for TextContextDocDetectorClient { + fn name(&self) -> &str { + "text_context_doc_detector" + } + + async fn health(&self) -> HealthCheckResult { + self.client.health().await + } +} + +/// A struct representing a request to a detector compatible with the +/// /api/v1/text/context/doc endpoint. +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Serialize)] +pub struct ContextDocsDetectionRequest { + /// Content to run detection on + pub content: String, + + /// Type of context being sent + pub context_type: ContextType, + + /// Context to run detection on + pub context: Vec, + + // Detector Params + pub detector_params: DetectorParams, +} + +/// Enum representing the context type of a detection +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum ContextType { + #[serde(rename = "docs")] + Document, + #[serde(rename = "url")] + Url, +} + +impl ContextDocsDetectionRequest { + pub fn new( + content: String, + context_type: ContextType, + context: Vec, + detector_params: DetectorParams, + ) -> Self { + Self { + content, + context_type, + context, + detector_params, + } + } +} diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs new file mode 100644 index 00000000..7b55893d --- /dev/null +++ b/src/clients/detector/text_generation.rs @@ -0,0 +1,88 @@ +use async_trait::async_trait; +use hyper::StatusCode; +use serde::Serialize; + +use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; +use crate::{ + clients::{Client, Error, HttpClient}, + health::HealthCheckResult, + models::DetectionResult, +}; + +#[cfg_attr(test, faux::create)] +#[derive(Clone)] +pub struct TextGenerationDetectorClient { + client: HttpClient, +} + +#[cfg_attr(test, faux::methods)] +impl TextGenerationDetectorClient { + pub fn new(client: HttpClient) -> Self { + Self { client } + } + + pub async fn text_generation( + &self, + model_id: &str, + request: GenerationDetectionRequest, + ) -> Result, Error> { + let url = self + .client + .base_url() + .join("/api/v1/text/generation") + .unwrap(); + let response = self + .client + .post(url) + .header(DETECTOR_ID_HEADER_NAME, model_id) + .json(&request) + .send() + .await?; + if response.status() == StatusCode::OK { + Ok(response.json().await?) + } else { + let code = response.status().as_u16(); + let error = response + .json::() + .await + .unwrap_or(DetectorError { + code, + message: "".into(), + }); + Err(error.into()) + } + } +} + +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for TextGenerationDetectorClient { + fn name(&self) -> &str { + "text_context_doc_detector" + } + + async fn health(&self) -> HealthCheckResult { + self.client.health().await + } +} + +/// A struct representing a request to a detector compatible with the +/// /api/v1/text/generation endpoint. +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Serialize)] +pub struct GenerationDetectionRequest { + /// User prompt sent to LLM + pub prompt: String, + + /// Text generated from an LLM + pub generated_text: String, +} + +impl GenerationDetectionRequest { + pub fn new(prompt: String, generated_text: String) -> Self { + Self { + prompt, + generated_text, + } + } +} diff --git a/src/clients/errors.rs b/src/clients/errors.rs new file mode 100644 index 00000000..306d0e78 --- /dev/null +++ b/src/clients/errors.rs @@ -0,0 +1,91 @@ +use std::error::Error as _; + +use hyper::StatusCode; +use tracing::error; + +/// Client errors. +#[derive(Debug, Clone, PartialEq, thiserror::Error)] +pub enum Error { + #[error("{}", .message)] + Grpc { code: StatusCode, message: String }, + #[error("{}", .message)] + Http { code: StatusCode, message: String }, + #[error("model not found: {model_id}")] + ModelNotFound { model_id: String }, +} + +impl Error { + /// Returns status code. + pub fn status_code(&self) -> StatusCode { + match self { + // Return equivalent http status code for grpc status code + Error::Grpc { code, .. } => *code, + // Return http status code for error responses + // and 500 for other errors + Error::Http { code, .. } => *code, + // Return 404 for model not found + Error::ModelNotFound { .. } => StatusCode::NOT_FOUND, + } + } +} + +impl From for Error { + fn from(value: reqwest::Error) -> Self { + // Log lower level source of error. + // Examples: + // 1. client error (Connect) // Cases like connection error, wrong port etc. + // 2. client error (SendRequest) // Cases like cert issues + error!( + "http request failed. Source: {}", + value.source().unwrap().to_string() + ); + // Return http status code for error responses + // and 500 for other errors + let code = match value.status() { + Some(code) => code, + None => StatusCode::INTERNAL_SERVER_ERROR, + }; + Self::Http { + code, + message: value.to_string(), + } + } +} + +impl From for Error { + fn from(value: tonic::Status) -> Self { + use tonic::Code::*; + // Return equivalent http status code for grpc status code + let code = match value.code() { + InvalidArgument => StatusCode::BAD_REQUEST, + Internal => StatusCode::INTERNAL_SERVER_ERROR, + NotFound => StatusCode::NOT_FOUND, + DeadlineExceeded => StatusCode::REQUEST_TIMEOUT, + Unimplemented => StatusCode::NOT_IMPLEMENTED, + Unauthenticated => StatusCode::UNAUTHORIZED, + PermissionDenied => StatusCode::FORBIDDEN, + Unavailable => StatusCode::SERVICE_UNAVAILABLE, + Ok => StatusCode::OK, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + Self::Grpc { + code, + message: value.message().to_string(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ClientCode { + Http(StatusCode), + Grpc(tonic::Code), +} + +impl std::fmt::Display for ClientCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ClientCode::Http(code) => write!(f, "HTTP {}", code), + ClientCode::Grpc(code) => write!(f, "gRPC {:?} {}", code, code), + } + } +} diff --git a/src/clients/generation.rs b/src/clients/generation.rs index d599d8c1..c10520dd 100644 --- a/src/clients/generation.rs +++ b/src/clients/generation.rs @@ -15,15 +15,14 @@ */ -use std::collections::HashMap; - +use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; use hyper::HeaderMap; use tracing::debug; -use super::{BoxStream, Error, NlpClient, TgisClient}; +use super::{BoxStream, Client, Error, NlpClient, TgisClient}; use crate::{ - health::{HealthCheckResult, HealthProbe}, + health::HealthCheckResult, models::{ ClassifiedGeneratedTextResult, ClassifiedGeneratedTextStreamResult, GuardrailsTextGenerationParameters, @@ -40,7 +39,7 @@ use crate::{ }, }; -#[cfg_attr(test, faux::create, derive(Default))] +#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct GenerationClient(Option); @@ -50,24 +49,6 @@ enum GenerationClientInner { Nlp(NlpClient), } -#[cfg_attr(test, faux::methods)] -impl HealthProbe for GenerationClient { - async fn health(&self) -> Result, Error> { - match &self.0 { - Some(GenerationClientInner::Tgis(client)) => client.health().await, - Some(GenerationClientInner::Nlp(client)) => client.health().await, - None => Ok(HashMap::new()), - } - } -} - -#[cfg(test)] -impl Default for GenerationClientInner { - fn default() -> Self { - Self::Tgis(TgisClient::default()) - } -} - #[cfg_attr(test, faux::methods)] impl GenerationClient { pub fn tgis(client: TgisClient) -> Self { @@ -253,3 +234,19 @@ impl GenerationClient { } } } + +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for GenerationClient { + fn name(&self) -> &str { + "generation" + } + + async fn health(&self) -> HealthCheckResult { + match &self.0 { + Some(GenerationClientInner::Tgis(client)) => client.health().await, + Some(GenerationClientInner::Nlp(client)) => client.health().await, + None => unimplemented!(), + } + } +} diff --git a/src/clients/http.rs b/src/clients/http.rs new file mode 100644 index 00000000..93c99b1d --- /dev/null +++ b/src/clients/http.rs @@ -0,0 +1,182 @@ +use hyper::StatusCode; +use reqwest::Response; +use tracing::error; +use url::Url; + +use super::ClientCode; +use crate::health::{HealthCheckResult, HealthStatus, OptionalHealthCheckResponseBody}; + +#[derive(Clone)] +pub struct HttpClient { + base_url: Url, + health_url: Url, + client: reqwest::Client, +} + +impl HttpClient { + pub fn new(base_url: Url, client: reqwest::Client) -> Self { + let health_url = base_url.join("health").unwrap(); + Self { + base_url, + health_url, + client, + } + } + + pub fn base_url(&self) -> &Url { + &self.base_url + } + + /// This is sectioned off to allow for testing. + pub(super) async fn http_response_to_health_check_result( + res: Result, + ) -> HealthCheckResult { + match res { + Ok(response) => { + if response.status() == StatusCode::OK { + if let Ok(body) = response.json::().await { + // If the service provided a body, we only anticipate a minimal health status and optional reason. + HealthCheckResult { + health_status: body.health_status.clone(), + response_code: ClientCode::Http(StatusCode::OK), + reason: match body.health_status { + HealthStatus::Healthy => None, + _ => body.reason, + }, + } + } else { + // If the service did not provide a body, we assume it is healthy. + HealthCheckResult { + health_status: HealthStatus::Healthy, + response_code: ClientCode::Http(StatusCode::OK), + reason: None, + } + } + } else { + HealthCheckResult { + // The most we can presume is that 5xx errors are likely indicating service issues, implying the service is unhealthy. + // and that 4xx errors are more likely indicating health check failures, i.e. due to configuration/implementation issues. + // Regardless we can't be certain, so the reason is also provided. + // TODO: We will likely circle back to re-evaluate this logic in the future + // when we know more about how the client health results will be used. + health_status: if response.status().as_u16() >= 500 + && response.status().as_u16() < 600 + { + HealthStatus::Unhealthy + } else if response.status().as_u16() >= 400 + && response.status().as_u16() < 500 + { + HealthStatus::Unknown + } else { + error!( + "unexpected http health check status code: {}", + response.status() + ); + HealthStatus::Unknown + }, + response_code: ClientCode::Http(response.status()), + reason: Some(format!( + "{}{}", + response.error_for_status_ref().unwrap_err(), + response + .text() + .await + .map(|s| if s.is_empty() { + "".to_string() + } else { + format!(": {}", s) + }) + .unwrap_or("".to_string()) + )), + } + } + } + Err(e) => { + error!("error checking health: {}", e); + HealthCheckResult { + health_status: HealthStatus::Unknown, + response_code: ClientCode::Http( + e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + ), + reason: Some(e.to_string()), + } + } + } + } + + pub async fn health(&self) -> HealthCheckResult { + let res = self.get(self.health_url.clone()).send().await; + Self::http_response_to_health_check_result(res).await + } +} + +impl std::ops::Deref for HttpClient { + type Target = reqwest::Client; + + fn deref(&self) -> &Self::Target { + &self.client + } +} + +/// Extracts a base url from a url including path segments. +pub fn extract_base_url(url: &Url) -> Option { + let mut url = url.clone(); + match url.path_segments_mut() { + Ok(mut path) => { + path.clear(); + } + Err(_) => { + return None; + } + } + Some(url) +} + +/// Returns `true` if url is a valid base url. +pub fn is_base_url(url: &str) -> bool { + if let Ok(url) = Url::parse(url) { + if let Some(base_url) = extract_base_url(&url) { + return url == base_url; + } + } + false +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_base_url() { + let url = + Url::parse("https://example-detector.route.example.com/api/v1/text/contents").unwrap(); + let base_url = extract_base_url(&url); + assert_eq!( + Some(Url::parse("https://example-detector.route.example.com/").unwrap()), + base_url + ); + let health_url = base_url.map(|v| v.join("/health").unwrap()); + assert_eq!( + Some(Url::parse("https://example-detector.route.example.com/health").unwrap()), + health_url + ); + } + + #[test] + fn test_is_base_url() { + let url = "http://localhost"; + assert!(is_base_url(url)); + + let url = "https://example-detector.route.example.com/"; + assert!(is_base_url(url)); + + let url = "https://example-detector.route.example.com"; + assert!(is_base_url(url)); + + let url = "https://example-detector.route.example.com/api/v1/text/contents"; + assert!(!is_base_url(url)); + + let url = "https://example-detector.route.example.com/api/v1/"; + assert!(!is_base_url(url)); + } +} diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 90b3f693..7987b3dc 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -15,18 +15,15 @@ */ -use std::collections::HashMap; - +use async_trait::async_trait; use axum::http::{Extensions, HeaderMap}; use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; -use tonic::{metadata::MetadataMap, Request}; +use tonic::{metadata::MetadataMap, Code, Request}; -use super::{create_grpc_clients, BoxStream, Error}; +use super::{BoxStream, Client, ClientCode, Error}; use crate::{ - clients::COMMON_ROUTER_KEY, - config::ServiceConfig, - health::{HealthCheckResult, HealthProbe}, + health::{HealthCheckResult, HealthStatus}, pb::{ caikit::runtime::nlp::{ nlp_service_client::NlpServiceClient, ServerStreamingTextGenerationTaskRequest, @@ -42,64 +39,34 @@ use crate::{ const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; -#[cfg_attr(test, faux::create, derive(Default))] +#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct NlpClient { - clients: HashMap>, - health_clients: HashMap>, -} - -#[cfg_attr(test, faux::methods)] -impl HealthProbe for NlpClient { - async fn health(&self) -> Result, Error> { - let mut results = HashMap::with_capacity(self.health_clients.len()); - for (model_id, mut client) in self.health_clients.clone() { - results.insert( - model_id.clone(), - client - .check(HealthCheckRequest { - service: model_id.clone(), - }) - .await - .into(), - ); - } - Ok(results) - } + client: NlpServiceClient, + health_client: HealthClient, } #[cfg_attr(test, faux::methods)] impl NlpClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { - let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await; - let health_clients = create_grpc_clients(default_port, config, HealthClient::new).await; + pub fn new( + client: NlpServiceClient, + health_client: HealthClient, + ) -> Self { Self { - clients, - health_clients, + client, + health_client, } } - fn client(&self, _model_id: &str) -> Result, Error> { - // NOTE: We currently forward requests to common router, so we use a single client. - let model_id = COMMON_ROUTER_KEY; - Ok(self - .clients - .get(model_id) - .ok_or_else(|| Error::ModelNotFound { - model_id: model_id.to_string(), - })? - .clone()) - } - pub async fn tokenization_task_predict( &self, model_id: &str, request: TokenizationTaskRequest, headers: HeaderMap, ) -> Result { + let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); - Ok(self - .client(model_id)? + Ok(client .tokenization_task_predict(request) .await? .into_inner()) @@ -111,9 +78,9 @@ impl NlpClient { request: TokenClassificationTaskRequest, headers: HeaderMap, ) -> Result { + let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); - Ok(self - .client(model_id)? + Ok(client .token_classification_task_predict(request) .await? .into_inner()) @@ -125,9 +92,9 @@ impl NlpClient { request: TextGenerationTaskRequest, headers: HeaderMap, ) -> Result { + let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); - Ok(self - .client(model_id)? + Ok(client .text_generation_task_predict(request) .await? .into_inner()) @@ -139,9 +106,9 @@ impl NlpClient { request: ServerStreamingTextGenerationTaskRequest, headers: HeaderMap, ) -> Result>, Error> { + let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); - let response_stream = self - .client(model_id)? + let response_stream = client .server_streaming_text_generation_task_predict(request) .await? .into_inner() @@ -151,6 +118,38 @@ impl NlpClient { } } +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for NlpClient { + fn name(&self) -> &str { + "nlp" + } + + async fn health(&self) -> HealthCheckResult { + let mut client = self.health_client.clone(); + let response = client + .check(HealthCheckRequest { service: "".into() }) + .await; + let code = match response { + Ok(_) => Code::Ok, + Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => { + Code::Ok + } + Err(status) => status.code(), + }; + let health_status = if matches!(code, Code::Ok) { + HealthStatus::Healthy + } else { + HealthStatus::Unhealthy + }; + HealthCheckResult { + health_status, + response_code: ClientCode::Grpc(code), + reason: None, + } + } +} + fn request_with_model_id(request: T, model_id: &str, headers: HeaderMap) -> Request { let metadata = MetadataMap::from_headers(headers); let mut request = Request::from_parts(metadata, Extensions::new(), request); diff --git a/src/clients/openai.rs b/src/clients/openai.rs new file mode 100644 index 00000000..19bbba2d --- /dev/null +++ b/src/clients/openai.rs @@ -0,0 +1,423 @@ +use std::collections::HashMap; + +use async_trait::async_trait; +use hyper::StatusCode; +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; + +use super::{Client, Error, HttpClient}; +use crate::health::HealthCheckResult; + +#[cfg_attr(test, faux::create)] +#[derive(Clone)] +pub struct OpenAiClient { + client: HttpClient, +} + +#[cfg_attr(test, faux::methods)] +impl OpenAiClient { + pub fn new(client: HttpClient) -> Self { + Self { client } + } + + pub async fn chat_completions( + &self, + request: ChatCompletionRequest, + ) -> Result { + let url = self.client.base_url().join("/v1/chat/completions").unwrap(); + let response = self.client.post(url).json(&request).send().await?; + match response.status() { + StatusCode::OK => Ok(response.json().await?), + _ => Err(Error::Http { + code: response.status(), + message: "".into(), // TODO + }), + } + } + + pub async fn completions( + &self, + request: CompletionRequest, + ) -> Result { + let url = self.client.base_url().join("/v1/completions").unwrap(); + let response = self.client.post(url).json(&request).send().await?; + match response.status() { + StatusCode::OK => Ok(response.json().await?), + _ => Err(Error::Http { + code: response.status(), + message: "".into(), // TODO + }), + } + } +} + +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for OpenAiClient { + fn name(&self) -> &str { + "openai" + } + + async fn health(&self) -> HealthCheckResult { + self.client.health().await + } +} + +/// Usage statistics for a completion. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Usage { + /// Number of tokens in the generated completion. + pub completion_tokens: u32, + /// Number of tokens in the prompt. + pub prompt_tokens: u32, + /// Total number of tokens used in the request (prompt + completion). + pub total_tokens: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum StopTokens { + Array(Vec), + String(String), +} + +// Chat completions API types + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionRequest { + /// ID of the model to use. + pub model: String, + /// A list of messages comprising the conversation so far. + pub messages: Vec, + #[serde(default)] + pub frequency_penalty: Option, + /// Modify the likelihood of specified tokens appearing in the completion. + #[serde(default)] + pub logit_bias: Option>, + /// Whether to return log probabilities of the output tokens or not. + /// If true, returns the log probabilities of each output token returned in the content of message. + #[serde(default)] + pub logprobs: Option, + /// An integer between 0 and 20 specifying the number of most likely tokens to return + /// at each token position, each with an associated log probability. + /// logprobs must be set to true if this parameter is used. + #[serde(default)] + pub top_logprobs: Option, + /// The maximum number of tokens that can be generated in the chat completion. + #[serde(default)] + pub max_tokens: Option, + /// How many chat completion choices to generate for each input message. + #[serde(default)] + pub n: Option, + /// Positive values penalize new tokens based on whether they appear in the text so far, + /// increasing the model's likelihood to talk about new topics. + #[serde(default)] + pub presence_penalty: Option, + //#[serde(default)] + //pub response_format: Option, + /// If specified, our system will make a best effort to sample deterministically, + /// such that repeated requests with the same seed and parameters should return the same result. + #[serde(default)] + pub seed: Option, + /// Up to 4 sequences where the API will stop generating further tokens. + #[serde(default)] + pub stop: Option, + /// If set, partial message deltas will be sent, like in ChatGPT. + /// Tokens will be sent as data-only server-sent events as they become available, + /// with the stream terminated by a data: [DONE] message. + #[serde(default)] + pub stream: Option, + /// What sampling temperature to use, between 0 and 2. + /// Higher values like 0.8 will make the output more random, + /// while lower values like 0.2 will make it more focused and deterministic. + #[serde(default)] + pub temperature: Option, + /// An alternative to sampling with temperature, called nucleus sampling, + /// where the model considers the results of the tokens with top_p probability mass. + /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. + #[serde(default)] + pub top_p: Option, + + // Additional vllm params + #[serde(default)] + pub best_of: Option, + #[serde(default)] + pub use_beam_search: Option, + #[serde(default)] + pub top_k: Option, + #[serde(default)] + pub min_p: Option, + #[serde(default)] + pub repetition_penalty: Option, + #[serde(default)] + pub length_penalty: Option, + #[serde(default)] + pub early_stopping: Option, + #[serde(default)] + pub ignore_eos: Option, + #[serde(default)] + pub min_tokens: Option, + #[serde(default)] + pub stop_token_ids: Option>, + #[serde(default)] + pub skip_special_tokens: Option, + #[serde(default)] + pub spaces_between_special_tokens: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl Message { + pub fn new(role: &str, content: &str, name: Option<&str>) -> Self { + Self { + role: role.into(), + content: content.into(), + name: name.map(|s| s.into()), + } + } +} + +/// Represents a chat completion response returned by model, based on the provided input. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionResponse { + /// A unique identifier for the chat completion. + pub id: String, + /// The object type, which is always `chat.completion`. + pub object: String, + /// The Unix timestamp (in seconds) of when the chat completion was created. + pub created: i64, + /// The model used for the chat completion. + pub model: String, + /// This fingerprint represents the backend configuration that the model runs with. + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + /// A list of chat completion choices. Can be more than one if n is greater than 1. + pub choices: Vec, + /// Usage statistics for the completion request. + pub usage: Usage, +} + +/// A chat completion choice. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionChoice { + /// The index of the choice in the list of choices. + pub index: usize, + /// A chat completion message generated by the model. + pub message: ChatCompletionMessage, + /// Log probability information for the choice. + pub logprobs: Option, + /// The reason the model stopped generating tokens. + pub finish_reason: String, +} + +/// A chat completion message generated by the model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionMessage { + /// The contents of the message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// The role of the author of this message. + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionLogprobs { + #[serde(skip_serializing_if = "Vec::is_empty")] + pub content: Vec, +} + +/// Log probability information for a choice. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionLogprob { + /// The token. + pub token: String, + /// The log probability of this token. + pub logprob: f32, + /// List of the most likely tokens and their log probability, at this token position. + pub top_logprobs: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionTopLogprob { + /// The token. + pub token: String, + /// The log probability of this token. + pub logprob: f32, +} + +/// Represents a streamed chunk of a chat completion response returned by model, based on the provided input. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionChunk { + /// A unique identifier for the chat completion. Each chunk has the same ID. + pub id: String, + /// A list of chat completion choices. + pub choices: Vec, + /// The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp. + pub created: i64, + /// The model to generate the completion. + pub model: String, + /// The object type, which is always `chat.completion.chunk`. + pub object: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionChunkChoice { + /// A chat completion delta generated by streamed model responses. + pub delta: ChatCompletionMessage, + /// The index of the choice in the list of choices. + pub index: u32, + /// Log probability information for the choice. + pub logprobs: Option, + /// The reason the model stopped generating tokens. + pub finish_reason: Option, +} + +// Completions API types + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletionRequest { + /// ID of the model to use. + pub model: String, + /// The prompt to generate completions for. + /// NOTE: Only supporting a single prompt for now. OpenAI supports a single string, + /// array of strings, array of tokens, or an array of token arrays. + pub prompt: String, + /// Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token). + /// Results cannot be streamed. When used with n, best_of controls the number of candidate completions and n specifies + /// how many to return – best_of must be greater than n. + #[serde(default)] + pub best_of: Option, + /// Echo back the prompt in addition to the completion. + #[serde(default)] + pub echo: Option, + /// Positive values penalize new tokens based on their existing frequency in the text so far, + /// decreasing the model's likelihood to repeat the same line verbatim. + #[serde(default)] + pub frequency_penalty: Option, + /// Modify the likelihood of specified tokens appearing in the completion. + #[serde(default)] + pub logit_bias: Option>, + /// Include the log probabilities on the logprobs most likely output tokens, as well the chosen tokens. + #[serde(default)] + pub logprobs: Option, + /// The maximum number of tokens that can be generated in the completion. + #[serde(default)] + pub max_tokens: Option, + /// How many completions to generate for each prompt. + #[serde(default)] + pub n: Option, + /// Positive values penalize new tokens based on whether they appear in the text so far, + /// increasing the model's likelihood to talk about new topics. + #[serde(default)] + pub presence_penalty: Option, + /// If specified, our system will make a best effort to sample deterministically, + /// such that repeated requests with the same seed and parameters should return the same result. + #[serde(default)] + pub seed: Option, + /// Up to 4 sequences where the API will stop generating further tokens. + /// The returned text will not contain the stop sequence. + #[serde(default)] + pub stop: Option, + /// Whether to stream back partial progress. + /// If set, tokens will be sent as data-only server-sent events as they become available, + /// with the stream terminated by a data: [DONE] message. + #[serde(default)] + pub stream: Option, + #[serde(default)] + /// The suffix that comes after a completion of inserted text. + pub suffix: Option, + /// What sampling temperature to use, between 0 and 2. + /// Higher values like 0.8 will make the output more random, + /// while lower values like 0.2 will make it more focused and deterministic. + #[serde(default)] + pub temperature: Option, + /// An alternative to sampling with temperature, called nucleus sampling, + /// where the model considers the results of the tokens with top_p probability mass. + /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. + #[serde(default)] + pub top_p: Option, + + // Additional vllm params + #[serde(default)] + pub use_beam_search: Option, + #[serde(default)] + pub top_k: Option, + #[serde(default)] + pub min_p: Option, + #[serde(default)] + pub repetition_penalty: Option, + #[serde(default)] + pub length_penalty: Option, + #[serde(default)] + pub early_stopping: Option, + #[serde(default)] + pub stop_token_ids: Option>, + #[serde(default)] + pub ignore_eos: Option, + #[serde(default)] + pub min_tokens: Option, + #[serde(default)] + pub skip_special_tokens: Option, + #[serde(default)] + pub spaces_between_special_tokens: Option, +} + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(untagged)] +// pub enum Prompt { +// Array(Vec), +// String(String), +// } + +/// Represents a completion response from the API. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletionResponse { + /// A unique identifier for the completion. + pub id: String, + /// The object type, which is always `text_completion`. + pub object: String, + /// The Unix timestamp (in seconds) of when the completion was created. + pub created: i64, + /// The model used for the completion. + pub model: String, + /// This fingerprint represents the backend configuration that the model runs with. + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + /// A list of completion choices. Can be more than one if n is greater than 1. + pub choices: Vec, + /// Usage statistics for the completion request. + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +/// A completion choice. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletionChoice { + /// The index of the choice in the list of choices. + pub index: u32, + /// A chat completion message generated by the model. + pub text: Option, + /// Log probability information for the choice. + pub logprobs: Option, + /// The reason the model stopped generating tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, +} + +/// Log probability information for a choice. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletionLogprobs { + pub text_offset: Vec, + pub token_logprobs: Vec, + pub tokens: Vec, + pub top_logprobs: Option>>, +} diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index bc7412de..040e61e9 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -14,18 +14,15 @@ limitations under the License. */ -use std::collections::HashMap; - +use async_trait::async_trait; use axum::http::HeaderMap; use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::Code; -use super::{create_grpc_clients, BoxStream, ClientCode, Error}; +use super::{BoxStream, Client, ClientCode, Error}; use crate::{ - clients::COMMON_ROUTER_KEY, - config::ServiceConfig, - health::{HealthCheckResult, HealthProbe, HealthStatus}, + health::{HealthCheckResult, HealthStatus}, pb::fmaas::{ generation_service_client::GenerationServiceClient, BatchedGenerationRequest, BatchedGenerationResponse, BatchedTokenizeRequest, BatchedTokenizeResponse, @@ -33,67 +30,16 @@ use crate::{ }, }; -#[cfg_attr(test, faux::create, derive(Default))] +#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TgisClient { - clients: HashMap>, -} - -#[cfg_attr(test, faux::methods)] -impl HealthProbe for TgisClient { - async fn health(&self) -> Result, Error> { - let mut results = HashMap::with_capacity(self.clients.len()); - for (model_id, mut client) in self.clients.clone() { - let response = client - .model_info(ModelInfoRequest { - model_id: "".into(), - }) - .await; - let code = match response { - Ok(_) => Code::Ok, - Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => { - Code::Ok - } - Err(status) => status.code(), - }; - let health_status = if matches!(code, Code::Ok) { - HealthStatus::Healthy - } else { - HealthStatus::Unhealthy - }; - results.insert( - model_id, - HealthCheckResult { - health_status, - response_code: ClientCode::Grpc(code), - reason: None, - }, - ); - } - Ok(results) - } + client: GenerationServiceClient, } #[cfg_attr(test, faux::methods)] impl TgisClient { - pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self { - let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await; - Self { clients } - } - - fn client( - &self, - _model_id: &str, - ) -> Result, Error> { - // NOTE: We currently forward requests to the common-router, so we use a single client. - let model_id = COMMON_ROUTER_KEY; - Ok(self - .clients - .get(model_id) - .ok_or_else(|| Error::ModelNotFound { - model_id: model_id.to_string(), - })? - .clone()) + pub fn new(client: GenerationServiceClient) -> Self { + Self { client } } pub async fn generate( @@ -101,8 +47,8 @@ impl TgisClient { request: BatchedGenerationRequest, _headers: HeaderMap, ) -> Result { - let model_id = request.model_id.as_str(); - Ok(self.client(model_id)?.generate(request).await?.into_inner()) + let mut client = self.client.clone(); + Ok(client.generate(request).await?.into_inner()) } pub async fn generate_stream( @@ -110,9 +56,8 @@ impl TgisClient { request: SingleGenerationRequest, _headers: HeaderMap, ) -> Result>, Error> { - let model_id = request.model_id.as_str(); - let response_stream = self - .client(model_id)? + let mut client = self.client.clone(); + let response_stream = client .generate_stream(request) .await? .into_inner() @@ -126,17 +71,47 @@ impl TgisClient { request: BatchedTokenizeRequest, _headers: HeaderMap, ) -> Result { - let model_id = request.model_id.as_str(); - Ok(self.client(model_id)?.tokenize(request).await?.into_inner()) + let mut client = self.client.clone(); + Ok(client.tokenize(request).await?.into_inner()) } pub async fn model_info(&self, request: ModelInfoRequest) -> Result { - let model_id = request.model_id.as_str(); - Ok(self - .client(model_id)? - .model_info(request) - .await? - .into_inner()) + let mut client = self.client.clone(); + Ok(client.model_info(request).await?.into_inner()) + } +} + +#[cfg_attr(test, faux::methods)] +#[async_trait] +impl Client for TgisClient { + fn name(&self) -> &str { + "tgis" + } + + async fn health(&self) -> HealthCheckResult { + let mut client = self.client.clone(); + let response = client + .model_info(ModelInfoRequest { + model_id: "".into(), + }) + .await; + let code = match response { + Ok(_) => Code::Ok, + Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => { + Code::Ok + } + Err(status) => status.code(), + }; + let health_status = if matches!(code, Code::Ok) { + HealthStatus::Healthy + } else { + HealthStatus::Unhealthy + }; + HealthCheckResult { + health_status, + response_code: ClientCode::Grpc(code), + reason: None, + } } } diff --git a/src/config.rs b/src/config.rs index 1ebc7879..7cef45d8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,7 @@ use std::{ use serde::Deserialize; use tracing::{debug, error, info, warn}; -use crate::clients::chunker::DEFAULT_MODEL_ID; +use crate::clients::{chunker::DEFAULT_MODEL_ID, http::is_base_url}; // Placeholder to add default allowed headers const DEFAULT_ALLOWED_HEADERS: &[&str] = &[]; @@ -47,6 +47,10 @@ pub enum Error { detector_id: String, chunker_id: String, }, + #[error("invalid generation provider: {0}")] + InvalidGenerationProvider(String), + #[error("invalid hostname: {0}")] + InvalidHostname(String), } /// Configuration for service needed for @@ -84,14 +88,17 @@ pub struct TlsConfig { /// Generation service provider #[cfg_attr(test, derive(Default))] #[derive(Clone, Copy, Debug, Deserialize)] -#[serde(rename_all = "lowercase")] pub enum GenerationProvider { #[cfg_attr(test, default)] + #[serde(rename = "tgis")] Tgis, + #[serde(rename = "nlp")] Nlp, + #[serde(rename = "openai")] + OpenAi, } -/// Generate service configuration +/// Generation service configuration #[cfg_attr(test, derive(Default))] #[derive(Clone, Debug, Deserialize)] pub struct GenerationConfig { @@ -101,6 +108,16 @@ pub struct GenerationConfig { pub service: ServiceConfig, } +/// Chat generation service configuration +#[cfg_attr(test, derive(Default))] +#[derive(Clone, Debug, Deserialize)] +pub struct ChatGenerationConfig { + /// Generation service provider + pub provider: GenerationProvider, + /// Generation service connection information + pub service: ServiceConfig, +} + /// Chunker parser type #[cfg_attr(test, derive(Default))] #[derive(Clone, Copy, Debug, Deserialize)] @@ -131,6 +148,20 @@ pub struct DetectorConfig { pub chunker_id: String, /// Default threshold with which to filter detector results by score pub default_threshold: f64, + /// Type of detection this detector performs + #[serde(rename = "type")] + pub r#type: DetectorType, +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(rename_all = "snake_case")] +#[non_exhaustive] +pub enum DetectorType { + #[default] + TextContents, + TextGeneration, + TextContextChat, + TextContextDoc, } /// Overall orchestrator server configuration @@ -139,6 +170,8 @@ pub struct DetectorConfig { pub struct OrchestratorConfig { /// Generation service and associated configuration, can be omitted if configuring for generation is not wanted pub generation: Option, + /// Chat generation service and associated configuration, can be omitted if configuring for chat generation is not wanted + pub chat_generation: Option, /// Chunker services and associated configurations, if omitted the default value "whole_doc_chunker" is used pub chunkers: Option>, /// Detector services and associated configurations @@ -206,6 +239,10 @@ impl OrchestratorConfig { if let Some(generation) = &mut self.generation { apply_named_tls_config(&mut generation.service, tls_configs)?; } + // Chat generation + if let Some(chat_generation) = &mut self.chat_generation { + apply_named_tls_config(&mut chat_generation.service, tls_configs)?; + } // Chunkers if let Some(chunkers) = &mut self.chunkers { for chunker in chunkers.values_mut() { @@ -221,25 +258,66 @@ impl OrchestratorConfig { } fn validate(&self) -> Result<(), Error> { + // Detectors are configured if self.detectors.is_empty() { - Err(Error::NoDetectorsConfigured) - } else { - for (detector_id, detector) in &self.detectors { - // Chunker is valid - let valid_chunker = detector.chunker_id == DEFAULT_MODEL_ID - || self - .chunkers - .as_ref() - .is_some_and(|chunkers| chunkers.contains_key(&detector.chunker_id)); - if !valid_chunker { - return Err(Error::DetectorChunkerNotFound { - detector_id: detector_id.clone(), - chunker_id: detector.chunker_id.clone(), - }); - } + return Err(Error::NoDetectorsConfigured); + } + + // Detector configs are valid + for (detector_id, detector) in &self.detectors { + // Hostname is valid + if !is_base_url(&detector.service.hostname) { + return Err(Error::InvalidHostname(format!( + "detector `{detector_id}` has an invalid hostname; \ + must be a base url, e.g. `https://service.route.example.com" + ))); + } + // Chunker is valid + let valid_chunker = detector.chunker_id == DEFAULT_MODEL_ID + || self + .chunkers + .as_ref() + .is_some_and(|chunkers| chunkers.contains_key(&detector.chunker_id)); + if !valid_chunker { + return Err(Error::DetectorChunkerNotFound { + detector_id: detector_id.clone(), + chunker_id: detector.chunker_id.clone(), + }); + } + } + + // Generation config is valid + if let Some(generation) = &self.generation { + // Provider is valid + if !matches!( + generation.provider, + GenerationProvider::Tgis | GenerationProvider::Nlp + ) { + return Err(Error::InvalidGenerationProvider( + "`generation` requires `tgis` or `nlp` provider".into(), + )); + } + } + + // Chat generation config is valid + if let Some(chat_generation) = &self.chat_generation { + // Provider is valid + if !matches!(chat_generation.provider, GenerationProvider::OpenAi) { + return Err(Error::InvalidGenerationProvider( + "`chat_generation` requires `openai` provider".into(), + )); + } + // Hostname is valid + if !is_base_url(&chat_generation.service.hostname) { + return Err(Error::InvalidHostname( + "`chat_generation` has an invalid hostname; \ + must be a base url, e.g. `https://service.route.example.com" + .into(), + )); } - Ok(()) } + + Ok(()) } /// Get ID of chunker associated with a particular detector @@ -301,6 +379,7 @@ chunkers: port: 9000 detectors: hap-en: + type: text_contents service: hostname: localhost port: 9000 @@ -341,6 +420,7 @@ chunkers: port: 9000 detectors: hap: + type: text_contents service: hostname: localhost port: 9000 @@ -393,6 +473,7 @@ chunkers: port: 9000 detectors: hap: + type: text_contents service: hostname: localhost port: 9000 @@ -487,6 +568,7 @@ chunkers: port: 9000 detectors: hap: + type: text_contents service: hostname: localhost port: 9000 @@ -527,6 +609,7 @@ chunkers: port: 9000 detectors: hap: + type: text_contents service: hostname: localhost port: 9000 @@ -543,10 +626,9 @@ tls: config .apply_named_tls_configs() .expect("Apply named TLS configs should have succeeded"); - let error = config + config .validate() .expect_err("Config should not have been validated"); - assert!(matches!(error, Error::DetectorChunkerNotFound { .. })) } #[test] diff --git a/src/health.rs b/src/health.rs index af43d94e..caae61aa 100644 --- a/src/health.rs +++ b/src/health.rs @@ -1,62 +1,108 @@ -use std::{collections::HashMap, fmt::Display, sync::Arc}; +use std::{collections::HashMap, fmt::Display}; -use axum::{ - http::StatusCode, - response::{IntoResponse, Response}, - Json, -}; +use axum::http::StatusCode; use serde::{ser::SerializeStruct, Deserialize, Serialize}; -use tokio::sync::RwLock; use tonic::Code; use tracing::{error, warn}; -use crate::{ - clients::{ClientCode, Error}, - pb::grpc::health::v1::HealthCheckResponse, -}; - -/// A health check endpoint for a singular client. -/// NOTE: Only implemented by HTTP clients, gRPC clients with health check support should use the generated `grpc::health::v1::health_client::HealthClient` service. -pub trait HealthCheck { - /// Makes a request to the client service health check endpoint and turns result into a `HealthCheckResult`. - fn check(&self) -> impl std::future::Future + Send; -} - -/// A health probe for aggregated health check results of multiple client services. -pub trait HealthProbe { - /// Makes a health check request to each client and returns a map of client service ids to health check results. - fn health( - &self, - ) -> impl std::future::Future, Error>> + Send; -} +use crate::{clients::ClientCode, pb::grpc::health::v1::HealthCheckResponse}; /// Health status determined for or returned by a client service. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "UPPERCASE")] pub enum HealthStatus { - /// The service is healthy and should be considered ready to serve requests. - #[serde(rename = "HEALTHY")] + /// The service status is healthy. Healthy, - /// The service is unhealthy and should be considered not ready to serve requests. - #[serde(rename = "UNHEALTHY")] + /// The service status is unhealthy. Unhealthy, - /// The health status of the service (and possibly the service itself) is unknown. - /// The health check response indicated the service's health is unknown or the health request failed in a way that could have been a misconfiguration, - /// meaning the actual service could still be healthy. - #[serde(rename = "UNKNOWN")] + /// The service status is unknown. Unknown, } -/// An optional response body that can be interpreted from an HTTP health check response. -/// This is a minimal contract that allows HTTP health requests to opt in to more detailed health check responses than just the status code. -/// If the body omitted, the health check response is considered successful if the status code is `HTTP 200 OK`. -#[derive(serde::Deserialize)] -pub struct OptionalHealthCheckResponseBody { - /// `HEALTHY`, `UNHEALTHY`, or `UNKNOWN`. Although `HEALTHY` is already implied without a body. - pub health_status: HealthStatus, - /// Optional reason for the health check result status being `UNHEALTHY` or `UNKNOWN`. - /// May be omitted overall if the health check was successful. - #[serde(default)] - pub reason: Option, +impl Display for HealthStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HealthStatus::Healthy => write!(f, "HEALTHY"), + HealthStatus::Unhealthy => write!(f, "UNHEALTHY"), + HealthStatus::Unknown => write!(f, "UNKNOWN"), + } + } +} + +impl From for HealthStatus { + fn from(value: HealthCheckResponse) -> Self { + // NOTE: gRPC Health v1 status codes: 0 = UNKNOWN, 1 = SERVING, 2 = NOT_SERVING, 3 = SERVICE_UNKNOWN + match value.status { + 1 => Self::Healthy, + 2 => Self::Unhealthy, + _ => Self::Unknown, + } + } +} + +impl From for HealthStatus { + fn from(code: StatusCode) -> Self { + match code.as_u16() { + 200 => Self::Healthy, + 201..=299 => { + warn!( + "Unexpected HTTP successful health check response status code: {}", + code + ); + Self::Healthy + } + 503 => Self::Unhealthy, + 500..=502 | 504..=599 => { + warn!( + "Unexpected HTTP server error health check response status code: {}", + code + ); + Self::Unhealthy + } + _ => { + warn!( + "Unexpected HTTP client error health check response status code: {}", + code + ); + Self::Unknown + } + } + } +} + +/// Holds health check results for all clients. +#[derive(Debug, Clone, Default, Serialize)] +pub struct ClientHealth(HashMap); + +impl ClientHealth { + pub fn new() -> Self { + Self(HashMap::new()) + } + + pub fn with_capacity(capacity: usize) -> Self { + Self(HashMap::with_capacity(capacity)) + } + + pub fn healthy(&self) -> bool { + !self + .0 + .iter() + .any(|(_, value)| matches!(value.health_status, HealthStatus::Unhealthy)) + } +} + +impl std::ops::Deref for ClientHealth { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for ClientHealth { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } /// Result of a health check request. @@ -73,29 +119,6 @@ pub struct HealthCheckResult { pub reason: Option, } -/// A cache to hold the latest health check results for each client service. -/// Orchestrator has a reference-counted mutex-protected instance of this cache. -#[derive(Debug, Clone, Default, Serialize)] -pub struct HealthCheckCache { - pub detectors: HashMap, - pub chunkers: HashMap, - pub generation: HashMap, -} - -/// Response for the readiness probe endpoint that holds a serialized cache of health check results for each client service. -#[derive(Debug, Clone, Serialize)] -pub struct HealthProbeResponse { - pub services: HealthCheckCache, -} - -/// Query param for triggering the client health check probe on the `/info` endpoint. -#[derive(Debug, Clone, Deserialize)] -pub struct HealthCheckProbeParams { - /// Whether to probe the client services' health checks or just return the cached health status. - #[serde(default)] - pub probe: bool, -} - impl HealthCheckResult { pub fn reason_from_health_check_response(response: &HealthCheckResponse) -> Option { match response.status { @@ -117,67 +140,6 @@ impl HealthCheckResult { } } -impl HealthCheckCache { - pub fn is_initialized(&self) -> bool { - !self.detectors.is_empty() && !self.chunkers.is_empty() && !self.generation.is_empty() - } -} - -impl HealthProbeResponse { - pub async fn from_cache(cache: Arc>) -> Self { - let services = cache.read().await.clone(); - Self { services } - } -} - -impl Display for HealthStatus { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - HealthStatus::Healthy => write!(f, "HEALTHY"), - HealthStatus::Unhealthy => write!(f, "UNHEALTHY"), - HealthStatus::Unknown => write!(f, "UNKNOWN"), - } - } -} - -impl Display for HealthCheckCache { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut services = vec![]; - let mut detectors = vec![]; - let mut chunkers = vec![]; - let mut generation = vec![]; - for (service, result) in &self.detectors { - detectors.push(format!("\t\t{}: {}", service, result)); - } - for (service, result) in &self.chunkers { - chunkers.push(format!("\t\t{}: {}", service, result)); - } - for (service, result) in &self.generation { - generation.push(format!("\t\t{}: {}", service, result)); - } - if !self.detectors.is_empty() { - services.push(format!("\tdetectors: {{\n{}\t}}", detectors.join(",\n"))); - } - if !self.chunkers.is_empty() { - services.push(format!("\tchunkers: {{\n{}\t}}", chunkers.join(",\n"))); - } - if !self.generation.is_empty() { - services.push(format!("\tgeneration: {{\n{}\t}}", generation.join(",\n"))); - } - write!( - f, - "configured client services: {{\n{}\n}}", - services.join(",\n") - ) - } -} - -impl Display for HealthProbeResponse { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.services) - } -} - impl Serialize for HealthCheckResult { fn serialize(&self, serializer: S) -> Result where @@ -237,49 +199,15 @@ impl From, tonic::Status>> for Healt } } -impl From for HealthStatus { - fn from(value: HealthCheckResponse) -> Self { - // NOTE: gRPC Health v1 status codes: 0 = UNKNOWN, 1 = SERVING, 2 = NOT_SERVING, 3 = SERVICE_UNKNOWN - match value.status { - 1 => Self::Healthy, - 2 => Self::Unhealthy, - _ => Self::Unknown, - } - } -} - -impl From for HealthStatus { - fn from(code: StatusCode) -> Self { - match code.as_u16() { - 200 => Self::Healthy, - 201..=299 => { - warn!( - "Unexpected HTTP successful health check response status code: {}", - code - ); - Self::Healthy - } - 503 => Self::Unhealthy, - 500..=502 | 504..=599 => { - warn!( - "Unexpected HTTP server error health check response status code: {}", - code - ); - Self::Unhealthy - } - _ => { - warn!( - "Unexpected HTTP client error health check response status code: {}", - code - ); - Self::Unknown - } - } - } -} - -impl IntoResponse for HealthProbeResponse { - fn into_response(self) -> Response { - (StatusCode::OK, Json(self)).into_response() - } +/// An optional response body that can be interpreted from an HTTP health check response. +/// This is a minimal contract that allows HTTP health requests to opt in to more detailed health check responses than just the status code. +/// If the body omitted, the health check response is considered successful if the status code is `HTTP 200 OK`. +#[derive(Deserialize)] +pub struct OptionalHealthCheckResponseBody { + /// `HEALTHY`, `UNHEALTHY`, or `UNKNOWN`. Although `HEALTHY` is already implied without a body. + pub health_status: HealthStatus, + /// Optional reason for the health check result status being `UNHEALTHY` or `UNKNOWN`. + /// May be omitted overall if the health check was successful. + #[serde(default)] + pub reason: Option, } diff --git a/src/lib.rs b/src/lib.rs index bfff8695..4ba228cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,7 @@ */ -#![allow(clippy::iter_kv_map, clippy::enum_variant_names)] +#![allow(clippy::iter_kv_map, clippy::enum_variant_names, async_fn_in_trait)] mod clients; pub mod config; diff --git a/src/models.rs b/src/models.rs index 1eb764ce..869e1f05 100644 --- a/src/models.rs +++ b/src/models.rs @@ -23,9 +23,22 @@ use serde::{Deserialize, Serialize}; use crate::{ clients::detector::{ContentAnalysisResponse, ContextType}, + health::ClientHealth, pb, }; +#[derive(Clone, Debug, Serialize)] +pub struct InfoResponse { + pub client_health: ClientHealth, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct InfoParams { + /// Whether to probe the client services' health checks or just return the latest health status. + #[serde(default)] + pub probe: bool, +} + /// Parameters relevant to each detector #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct DetectorParams(HashMap); diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 607ae96a..500ac1da 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -29,18 +29,38 @@ use uuid::Uuid; use crate::{ clients::{ - self, detector::ContextType, ChunkerClient, DetectorClient, GenerationClient, NlpClient, - TgisClient, COMMON_ROUTER_KEY, + create_grpc_client, create_http_client, + detector::{ + text_context_doc::ContextType, TextContextChatDetectorClient, + TextContextDocDetectorClient, TextGenerationDetectorClient, + }, + openai::OpenAiClient, + ChunkerClient, ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, + TgisClient, }, - config::{GenerationProvider, OrchestratorConfig}, - health::{HealthCheckCache, HealthProbe, HealthProbeResponse}, + config::{DetectorType, GenerationProvider, OrchestratorConfig}, + health::ClientHealth, models::{ ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest, }, + pb::{ + caikit::runtime::{ + chunkers::chunkers_service_client::ChunkersServiceClient, + nlp::nlp_service_client::NlpServiceClient, + }, + fmaas::generation_service_client::GenerationServiceClient, + grpc::health::v1::health_client::HealthClient, + }, }; +const DEFAULT_TGIS_PORT: u16 = 8033; +const DEFAULT_NLP_PORT: u16 = 8085; +const DEFAULT_CHUNKER_PORT: u16 = 8085; +const DEFAULT_OPENAI_PORT: u16 = 8080; +const DEFAULT_DETECTOR_PORT: u16 = 8080; + const UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. \ Please check the detected entities on your input and try again \ with the unsuitable input removed."; @@ -48,16 +68,20 @@ const UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. \ #[cfg_attr(test, derive(Default))] pub struct Context { config: OrchestratorConfig, - generation_client: GenerationClient, - chunker_client: ChunkerClient, - detector_client: DetectorClient, + clients: ClientMap, +} + +impl Context { + pub fn new(config: OrchestratorConfig, clients: ClientMap) -> Self { + Self { config, clients } + } } /// Handles orchestrator tasks. #[cfg_attr(test, derive(Default))] pub struct Orchestrator { ctx: Arc, - client_health_cache: Arc>, + client_health: Arc>, } impl Orchestrator { @@ -65,16 +89,11 @@ impl Orchestrator { config: OrchestratorConfig, start_up_health_check: bool, ) -> Result { - let (generation_client, chunker_client, detector_client) = create_clients(&config).await; - let ctx = Arc::new(Context { - config, - generation_client, - chunker_client, - detector_client, - }); + let clients = create_clients(&config).await; + let ctx = Arc::new(Context { config, clients }); let orchestrator = Self { ctx, - client_health_cache: Arc::new(RwLock::new(HealthCheckCache::default())), + client_health: Arc::new(RwLock::new(ClientHealth::default())), }; debug!("running start up checks"); orchestrator.on_start_up(start_up_health_check).await?; @@ -92,37 +111,34 @@ impl Orchestrator { pub async fn on_start_up(&self, health_check: bool) -> Result<(), Error> { info!("Performing start-up actions for orchestrator..."); if health_check { - info!("Probing health status of configured clients..."); - // Run probe, update cache - let res = self.clients_health(true).await.unwrap_or_else(|e| { - // Panic for unexpected behaviour as there are currently no errors propagated to here. - panic!("Unexpected error during client health probing: {}", e); - }); + info!("Probing client health..."); + let client_health = self.client_health(true).await; // Results of probe do not affect orchestrator start-up. - info!("Orchestrator client health probe results:\n{}", res); + info!("Client health: {client_health:?}"); // TODO: re-impl Display } Ok(()) } - pub async fn clients_health(&self, probe: bool) -> Result { - let initialized = self.client_health_cache.read().await.is_initialized(); + /// Returns client health state. + pub async fn client_health(&self, probe: bool) -> ClientHealth { + let initialized = !self.client_health.read().await.is_empty(); if probe || !initialized { - debug!("refreshing health cache"); + debug!("refreshing client health"); let now = Instant::now(); - let detectors = self.ctx.detector_client.health().await?; - let chunkers = self.ctx.chunker_client.health().await?; - let generation = self.ctx.generation_client.health().await?; - let mut health_cache = self.client_health_cache.write().await; - health_cache.detectors = detectors; - health_cache.chunkers = chunkers; - health_cache.generation = generation; + let mut state = ClientHealth::with_capacity(self.ctx.clients.len()); + // TODO: perform health checks concurrently? + for (key, client) in self.ctx.clients.iter() { + let result = client.health().await; + state.insert(key.into(), result); + } + let mut client_health = self.client_health.write().await; + *client_health = state; debug!( - "refreshing health cache completed in {:.2?}ms", + "refreshing client health completed in {:.2?}ms", now.elapsed().as_millis() ); } - - Ok(HealthProbeResponse::from_cache(self.client_health_cache.clone()).await) + self.client_health.read().await.clone() } } @@ -162,50 +178,98 @@ fn get_chunker_ids( .collect::, Error>>() } -async fn create_clients( - config: &OrchestratorConfig, -) -> (GenerationClient, ChunkerClient, DetectorClient) { - // TODO: create better solution for routers - let generation_client = match &config.generation { - Some(generation) => match &generation.provider { +async fn create_clients(config: &OrchestratorConfig) -> ClientMap { + let mut clients = ClientMap::new(); + + // Create generation client + if let Some(generation) = &config.generation { + match generation.provider { GenerationProvider::Tgis => { - let client = TgisClient::new( - clients::DEFAULT_TGIS_PORT, - &[(COMMON_ROUTER_KEY.to_string(), generation.service.clone())], + let client = create_grpc_client( + DEFAULT_TGIS_PORT, + &generation.service, + GenerationServiceClient::new, ) .await; - GenerationClient::tgis(client) + let tgis_client = TgisClient::new(client); + let generation_client = GenerationClient::tgis(tgis_client); + clients.insert("generation".to_string(), generation_client); } GenerationProvider::Nlp => { - let client = NlpClient::new( - clients::DEFAULT_CAIKIT_NLP_PORT, - &[(COMMON_ROUTER_KEY.to_string(), generation.service.clone())], + let client = create_grpc_client( + DEFAULT_NLP_PORT, + &generation.service, + NlpServiceClient::new, ) .await; - GenerationClient::nlp(client) + let health_client = + create_grpc_client(DEFAULT_NLP_PORT, &generation.service, HealthClient::new) + .await; + let nlp_client = NlpClient::new(client, health_client); + let generation_client = GenerationClient::nlp(nlp_client); + clients.insert("generation".to_string(), generation_client); } - }, - None => GenerationClient::not_configured(), - }; - // TODO: simplify all of this - let chunker_config = match &config.chunkers { - Some(chunkers) => chunkers - .iter() - .map(|(chunker_id, config)| (chunker_id.clone(), config.service.clone())) - .collect::>(), - None => vec![], - }; - let chunker_client = ChunkerClient::new(clients::DEFAULT_CHUNKER_PORT, &chunker_config).await; - - let detector_config = config - .detectors - .iter() - .map(|(detector_id, config)| (detector_id.clone(), config.service.clone())) - .collect::>(); - let detector_client = - DetectorClient::new(clients::DEFAULT_DETECTOR_PORT, &detector_config).await; - - (generation_client, chunker_client, detector_client) + GenerationProvider::OpenAi => unimplemented!(), + } + } + + // Create chat generation client + if let Some(chat_generation) = &config.chat_generation { + match chat_generation.provider { + GenerationProvider::OpenAi => { + let client = + create_http_client(DEFAULT_OPENAI_PORT, &chat_generation.service).await; + let openai_client = OpenAiClient::new(client); + clients.insert("chat_generation".to_string(), openai_client); + } + _ => unimplemented!(), + } + } + + // Create chunker clients + if let Some(chunkers) = &config.chunkers { + for (chunker_id, chunker) in chunkers { + let client = create_grpc_client( + DEFAULT_CHUNKER_PORT, + &chunker.service, + ChunkersServiceClient::new, + ) + .await; + let health_client = + create_grpc_client(DEFAULT_CHUNKER_PORT, &chunker.service, HealthClient::new).await; + let chunker_client = ChunkerClient::new(client, health_client); + clients.insert(chunker_id.to_string(), chunker_client); + } + } + + // Create detector clients + for (detector_id, detector) in &config.detectors { + let client = create_http_client(DEFAULT_DETECTOR_PORT, &detector.service).await; + match detector.r#type { + DetectorType::TextContents => { + clients.insert(detector_id.into(), TextContentsDetectorClient::new(client)); + } + DetectorType::TextGeneration => { + clients.insert( + detector_id.into(), + TextGenerationDetectorClient::new(client), + ); + } + DetectorType::TextContextChat => { + clients.insert( + detector_id.into(), + TextContextChatDetectorClient::new(client), + ); + } + DetectorType::TextContextDoc => { + clients.insert( + detector_id.into(), + TextContextDocDetectorClient::new(client), + ); + } + } + } + clients } #[derive(Debug, Clone)] diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index 0e738d83..b8697df1 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -22,14 +22,16 @@ use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; use aggregator::Aggregator; use axum::http::HeaderMap; use futures::{future::try_join_all, Stream, StreamExt, TryStreamExt}; - use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; use tracing::{debug, error, info, instrument}; use super::{get_chunker_ids, Context, Error, Orchestrator, StreamingClassificationWithGenTask}; use crate::{ - clients::detector::ContentAnalysisRequest, + clients::{ + detector::ContentAnalysisRequest, ChunkerClient, GenerationClient, + TextContentsDetectorClient, + }, models::{ ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsTextGenerationParameters, InputWarning, InputWarningReason, TextGenTokenClassificationResults, @@ -39,8 +41,7 @@ use crate::{ unary::{input_detection_task, tokenize}, UNSUITABLE_INPUT_MESSAGE, }, - pb::caikit::runtime::chunkers, - pb::caikit_data_model::nlp::ChunkerTokenizationStreamResult, + pb::{caikit::runtime::chunkers, caikit_data_model::nlp::ChunkerTokenizationStreamResult}, }; pub type Chunk = ChunkerTokenizationStreamResult; @@ -381,9 +382,11 @@ async fn detection_task( let request = ContentAnalysisRequest::new(contents.clone()); let headers = headers.clone(); debug!(%detector_id, ?request, "sending detector request"); - match ctx - .detector_client - .text_contents(&detector_id, request, headers) + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + match client.text_contents(&detector_id, request, headers) .await .map_err(|error| Error::DetectorRequestFailed { id: detector_id.clone(), error }) { Ok(response) => { @@ -452,8 +455,8 @@ async fn chunk_broadcast_task( .boxed(); debug!(%chunker_id, "creating chunker output stream"); let id = chunker_id.clone(); // workaround for StreamExt::map_err - let mut output_stream = ctx - .chunker_client + let client = ctx.clients.get_as::(&chunker_id).unwrap(); + let mut output_stream = client .bidi_streaming_tokenization_task_predict(&chunker_id, input_stream) .await .map_err(|error| Error::ChunkerRequestFailed { @@ -511,8 +514,11 @@ async fn generate_stream( Pin> + Send>>, Error, > { - Ok(ctx - .generation_client + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + Ok(client .generate_stream(model_id.clone(), text, params, headers) .await .map_err(|error| Error::GenerateRequestFailed { diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 4c30e012..149695e5 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -30,9 +30,13 @@ use super::{ Orchestrator, TextContentDetectionTask, }; use crate::{ - clients::detector::{ - ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest, ContextType, - GenerationDetectionRequest, + clients::{ + detector::{ + ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest, + ContextType, GenerationDetectionRequest, TextContentsDetectorClient, + TextContextDocDetectorClient, TextGenerationDetectorClient, + }, + ChunkerClient, GenerationClient, }, models::{ ClassifiedGeneratedTextResult, ContextDocsResult, DetectionOnGenerationResult, @@ -567,7 +571,11 @@ pub async fn detect( } else { let request = ContentAnalysisRequest::new(contents); debug!(%detector_id, ?request, "sending detector request"); - ctx.detector_client + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + client .text_contents(&detector_id, request, headers) .await .map_err(|error| { @@ -622,7 +630,11 @@ pub async fn detect_content( } else { let request = ContentAnalysisRequest::new(contents); debug!(%detector_id, ?request, "sending detector request"); - ctx.detector_client + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + client .text_contents(&detector_id, request, headers) .await .map_err(|error| { @@ -677,8 +689,11 @@ pub async fn detect_for_generation( ); let request = GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()); debug!(%detector_id, ?request, "sending generation detector request"); - let response = ctx - .detector_client + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + let response = client .text_generation(&detector_id, request, headers) .await .map(|results| { @@ -717,9 +732,12 @@ pub async fn detect_for_context( ); let request = ContextDocsDetectionRequest::new(content, context_type, context, detector_params); debug!(%detector_id, ?request, "sending context detector request"); - let response = ctx - .detector_client - .text_context_doc(&detector_id, request, headers) + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + let response = client + .text_context_docs(&detector_id, request, headers) .await .map(|results| { results @@ -745,8 +763,8 @@ pub async fn chunk( ) -> Result, Error> { let request = chunkers::ChunkerTokenizationTaskRequest { text }; debug!(%chunker_id, ?request, "sending chunker request"); - let response = ctx - .chunker_client + let client = ctx.clients.get_as::(&chunker_id).unwrap(); + let response = client .tokenization_task_predict(&chunker_id, request) .await .map_err(|error| Error::ChunkerRequestFailed { @@ -797,7 +815,11 @@ pub async fn tokenize( text: String, headers: HeaderMap, ) -> Result<(u32, Vec), Error> { - ctx.generation_client + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + client .tokenize(model_id.clone(), text, headers) .await .map_err(|error| Error::TokenizeRequestFailed { @@ -814,7 +836,11 @@ async fn generate( params: Option, headers: HeaderMap, ) -> Result { - ctx.generation_client + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + client .generate(model_id.clone(), text, params, headers) .await .map_err(|error| Error::GenerateRequestFailed { @@ -832,7 +858,7 @@ mod tests { clients::{ self, detector::{ContentAnalysisResponse, GenerationDetectionRequest}, - ChunkerClient, DetectorClient, GenerationClient, TgisClient, + ClientMap, GenerationClient, TgisClient, }, config::{DetectorConfig, OrchestratorConfig}, models::{DetectionResult, EvidenceObj, FinishReason}, @@ -842,27 +868,10 @@ mod tests { }, }; - async fn get_test_context( - gen_client: GenerationClient, - chunker_client: Option, - detector_client: Option, - ) -> Context { - let chunker_client = chunker_client.unwrap_or_default(); - let detector_client = detector_client.unwrap_or_default(); - - Context { - generation_client: gen_client, - chunker_client, - detector_client, - config: OrchestratorConfig::default(), - } - } - // Test for TGIS generation with default parameter #[tokio::test] async fn test_tgis_generate_with_default_params() { - // Initialize a mock object from `TgisClient` - let mut mock_client = TgisClient::faux(); + let mut tgis_client = TgisClient::faux(); let sample_text = String::from("sample text"); let text_gen_model_id = String::from("test-llm-id-1"); @@ -899,13 +908,16 @@ mod tests { }; // Construct a behavior for the mock object - faux::when!(mock_client.generate(expected_generate_req_args, HeaderMap::new())) + let headers = HeaderMap::new(); + faux::when!(tgis_client.generate(expected_generate_req_args, headers)) .once() // TODO: Add with_args .then_return(Ok(client_generation_response)); - let mock_generation_client = GenerationClient::tgis(mock_client.clone()); + let generation_client = GenerationClient::tgis(tgis_client.clone()); - let ctx = Arc::new(get_test_context(mock_generation_client, None, None).await); + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); // Test request formulation and response processing is as expected assert_eq!( @@ -925,8 +937,8 @@ mod tests { /// 2. detections below the threshold are not returned to the client. #[tokio::test] async fn test_handle_detection_task() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextContentsDetectorClient::faux(); let detector_id = "mocked_hap_detector"; let threshold = 0.5; @@ -957,7 +969,7 @@ mod tests { token_count: None, }]; - faux::when!(mock_detector_client.text_contents( + faux::when!(detector_client.text_contents( detector_id, ContentAnalysisRequest::new(vec![first_sentence.clone(), second_sentence.clone()]), HeaderMap::new(), @@ -984,12 +996,14 @@ mod tests { }], ])); - let ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); assert_eq!( detect( - ctx.into(), + ctx, detector_id.to_string(), threshold, detector_params, @@ -1005,8 +1019,8 @@ mod tests { /// This test checks if calls to detectors returning 503 are being propagated in the orchestrator response. #[tokio::test] async fn test_detect_when_detector_returns_503() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextContentsDetectorClient::faux(); let detector_id = "mocked_503_detector"; let sentence = "This call will return a 503.".to_string(); @@ -1027,7 +1041,7 @@ mod tests { }, }; - faux::when!(mock_detector_client.text_contents( + faux::when!(detector_client.text_contents( detector_id, ContentAnalysisRequest::new(vec![sentence.clone()]), HeaderMap::new(), @@ -1038,12 +1052,14 @@ mod tests { message: "Service Unavailable".to_string(), })); - let ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); assert_eq!( detect( - ctx.into(), + ctx, detector_id.to_string(), threshold, detector_params, @@ -1055,10 +1071,11 @@ mod tests { expected_response ); } + #[tokio::test] async fn test_handle_detection_task_with_whitespace() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextContentsDetectorClient::faux(); let detector_id = "mocked_hap_detector"; let threshold = 0.5; @@ -1070,7 +1087,7 @@ mod tests { text: first_sentence.clone(), }]; - faux::when!(mock_detector_client.text_contents( + faux::when!(detector_client.text_contents( detector_id, ContentAnalysisRequest::new(vec![first_sentence.clone()]), HeaderMap::new(), @@ -1078,12 +1095,15 @@ mod tests { .once() .then_return(Ok(vec![vec![]])); - let ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); + let expected_response_whitespace = vec![]; assert_eq!( detect( - ctx.into(), + ctx, detector_id.to_string(), threshold, detector_params, @@ -1095,11 +1115,11 @@ mod tests { expected_response_whitespace ); } - /// This test checks if calls to detectors for the /generation-detection endpoint are being handled appropriately. + #[tokio::test] async fn test_detect_for_generation() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextGenerationDetectorClient::faux(); let detector_id = "mocked_answer_relevance_detector"; let threshold = 0.5; @@ -1123,7 +1143,7 @@ mod tests { ), }]; - faux::when!(mock_detector_client.text_generation( + faux::when!(detector_client.text_generation( detector_id, GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()), HeaderMap::new(), @@ -1144,9 +1164,10 @@ mod tests { ), }])); - let mut ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; - + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let mut ctx = Context::new(OrchestratorConfig::default(), clients); // add detector ctx.config.detectors.insert( detector_id.to_string(), @@ -1154,10 +1175,11 @@ mod tests { ..Default::default() }, ); + let ctx = Arc::new(ctx); assert_eq!( detect_for_generation( - ctx.into(), + ctx, detector_id.to_string(), detector_params, prompt, @@ -1170,11 +1192,10 @@ mod tests { ); } - /// This test checks if calls to detectors for the /generation-detection endpoint only return detections above the threshold. #[tokio::test] async fn test_detect_for_generation_below_threshold() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextGenerationDetectorClient::faux(); let detector_id = "mocked_answer_relevance_detector"; let threshold = 0.5; @@ -1186,7 +1207,7 @@ mod tests { let expected_response: Vec = vec![]; - faux::when!(mock_detector_client.text_generation( + faux::when!(detector_client.text_generation( detector_id, GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()), HeaderMap::new(), @@ -1199,20 +1220,22 @@ mod tests { evidence: None, }])); - let mut ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; - - // add mocked detector + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let mut ctx = Context::new(OrchestratorConfig::default(), clients); + // add detector ctx.config.detectors.insert( detector_id.to_string(), DetectorConfig { ..Default::default() }, ); + let ctx = Arc::new(ctx); assert_eq!( detect_for_generation( - ctx.into(), + ctx, detector_id.to_string(), detector_params, prompt, diff --git a/src/server.rs b/src/server.rs index f88f861d..6343d29a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -49,8 +49,7 @@ use uuid::Uuid; use webpki::types::{CertificateDer, PrivateKeyDer}; use crate::{ - health::HealthCheckProbeParams, - models, + models::{self, InfoParams, InfoResponse}, orchestrator::{ self, ClassificationWithGenTask, ContextDocsDetectionTask, DetectionOnGenerationTask, GenerationWithDetectionTask, Orchestrator, StreamingClassificationWithGenTask, @@ -294,18 +293,10 @@ async fn health() -> Result { async fn info( State(state): State>, - Query(params): Query, -) -> Result { - match state.orchestrator.clients_health(params.probe).await { - Ok(client_health_info) => Ok(client_health_info), - Err(error) => { - error!( - "Unexpected internal error while checking client health info: {:?}", - error - ); - Err(error.into()) - } - } + Query(params): Query, +) -> Result, Error> { + let client_health = state.orchestrator.client_health(params.probe).await; + Ok(Json(InfoResponse { client_health })) } async fn classification_with_gen( diff --git a/tests/test.config.yaml b/tests/test.config.yaml index 6ca749db..773b85b7 100644 --- a/tests/test.config.yaml +++ b/tests/test.config.yaml @@ -11,8 +11,9 @@ chunkers: port: 8085 detectors: test_detector: + type: text_contents service: - hostname: https://localhost/api/v1/text/contents + hostname: https://localhost port: 8000 chunker_id: test_chunker default_threshold: 0.5 From 3861b12f9295257f0a8f6250bcc682609649c78a Mon Sep 17 00:00:00 2001 From: declark1 Date: Wed, 2 Oct 2024 15:19:37 -0700 Subject: [PATCH 02/50] Add is_valid_hostname(), update hostname validation, infer protocol from tls config Co-authored-by: Mateus Devino Signed-off-by: declark1 --- config/config.yaml | 4 ++-- config/test.config.yaml | 4 ++-- src/clients.rs | 50 ++++++++++++++++++++++++++++++++++++++++- src/clients/http.rs | 28 ----------------------- src/config.rs | 31 ++++++++++++++++++------- tests/test.config.yaml | 2 +- 6 files changed, 77 insertions(+), 42 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 4f13ea18..6ce600aa 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -15,7 +15,7 @@ generation: # chat_generation: # provider: openai # service: -# hostname: http://localhost +# hostname: localhost # port: 8080 # Any chunker servers that will be used by any detectors chunkers: @@ -35,7 +35,7 @@ detectors: # Detector type (text_contents, text_generation, text_context_chat, text_context_doc) type: text_contents service: - hostname: https://localhost + hostname: localhost port: 8080 # TLS ID/name, optional (detailed in `tls` section) tls: detector diff --git a/config/test.config.yaml b/config/test.config.yaml index 0decc09a..983a6234 100644 --- a/config/test.config.yaml +++ b/config/test.config.yaml @@ -6,7 +6,7 @@ generation: # chat_generation: # provider: openai # service: -# hostname: http://localhost +# hostname: localhost # port: 8080 chunkers: test_chunker: @@ -18,7 +18,7 @@ detectors: test_detector: type: text_contents service: - hostname: https://localhost + hostname: localhost port: 8000 chunker_id: test_chunker default_threshold: 0.5 diff --git a/src/clients.rs b/src/clients.rs index dc87d5e6..93da58a9 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -196,7 +196,11 @@ impl ClientMap { pub async fn create_http_client(default_port: u16, service_config: &ServiceConfig) -> HttpClient { let port = service_config.port.unwrap_or(default_port); - let mut base_url = Url::parse(&service_config.hostname).unwrap(); + let protocol = match service_config.tls { + Some(_) => "https", + None => "http", + }; + let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap(); base_url.set_port(Some(port)).unwrap(); let request_timeout = Duration::from_secs( service_config @@ -302,6 +306,32 @@ pub async fn create_grpc_client( new(channel) } +/// Returns `true` if hostname is valid according to [IETF RFC 1123](https://tools.ietf.org/html/rfc1123). +/// +/// Conditions: +/// - It does not start or end with `-` or `.`. +/// - It does not contain any characters outside of the alphanumeric range, except for `-` and `.`. +/// - It is not empty. +/// - It is 253 or fewer characters. +/// - Its labels (characters separated by `.`) are not empty. +/// - Its labels are 63 or fewer characters. +/// - Its labels do not start or end with '-' or '.'. +pub fn is_valid_hostname(hostname: &str) -> bool { + fn is_valid_char(byte: u8) -> bool { + byte.is_ascii_lowercase() + || byte.is_ascii_uppercase() + || byte.is_ascii_digit() + || byte == b'-' + || byte == b'.' + } + !(hostname.bytes().any(|byte| !is_valid_char(byte)) + || hostname.split('.').any(|label| { + label.is_empty() || label.len() > 63 || label.starts_with('-') || label.ends_with('-') + }) + || hostname.is_empty() + || hostname.len() > 253) +} + #[cfg(test)] mod tests { use hyper::{http, StatusCode}; @@ -576,4 +606,22 @@ mod tests { ); } } + + #[test] + fn test_is_valid_hostname() { + let valid_hostnames = ["localhost", "example.route.cloud.com", "127.0.0.1"]; + for hostname in valid_hostnames { + assert!(is_valid_hostname(hostname)); + } + let invalid_hostnames = [ + "-LoCaLhOST_", + ".invalid", + "invalid.ending-.char", + "@asdf", + "too-long-of-a-hostnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee", + ]; + for hostname in invalid_hostnames { + assert!(!is_valid_hostname(hostname)); + } + } } diff --git a/src/clients/http.rs b/src/clients/http.rs index 93c99b1d..f6d1180c 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -132,16 +132,6 @@ pub fn extract_base_url(url: &Url) -> Option { Some(url) } -/// Returns `true` if url is a valid base url. -pub fn is_base_url(url: &str) -> bool { - if let Ok(url) = Url::parse(url) { - if let Some(base_url) = extract_base_url(&url) { - return url == base_url; - } - } - false -} - #[cfg(test)] mod tests { use super::*; @@ -161,22 +151,4 @@ mod tests { health_url ); } - - #[test] - fn test_is_base_url() { - let url = "http://localhost"; - assert!(is_base_url(url)); - - let url = "https://example-detector.route.example.com/"; - assert!(is_base_url(url)); - - let url = "https://example-detector.route.example.com"; - assert!(is_base_url(url)); - - let url = "https://example-detector.route.example.com/api/v1/text/contents"; - assert!(!is_base_url(url)); - - let url = "https://example-detector.route.example.com/api/v1/"; - assert!(!is_base_url(url)); - } } diff --git a/src/config.rs b/src/config.rs index 7cef45d8..c997fd51 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,7 @@ use std::{ use serde::Deserialize; use tracing::{debug, error, info, warn}; -use crate::clients::{chunker::DEFAULT_MODEL_ID, http::is_base_url}; +use crate::clients::{chunker::DEFAULT_MODEL_ID, is_valid_hostname}; // Placeholder to add default allowed headers const DEFAULT_ALLOWED_HEADERS: &[&str] = &[]; @@ -266,10 +266,9 @@ impl OrchestratorConfig { // Detector configs are valid for (detector_id, detector) in &self.detectors { // Hostname is valid - if !is_base_url(&detector.service.hostname) { + if !is_valid_hostname(&detector.service.hostname) { return Err(Error::InvalidHostname(format!( - "detector `{detector_id}` has an invalid hostname; \ - must be a base url, e.g. `https://service.route.example.com" + "detector `{detector_id}` has an invalid hostname" ))); } // Chunker is valid @@ -286,6 +285,18 @@ impl OrchestratorConfig { } } + // Chunker config is valid + if let Some(chunkers) = &self.chunkers { + for (chunker_id, chunker) in chunkers { + // Hostname is valid + if !is_valid_hostname(&chunker.service.hostname) { + return Err(Error::InvalidHostname(format!( + "chunker `{chunker_id}` has an invalid hostname" + ))); + } + } + } + // Generation config is valid if let Some(generation) = &self.generation { // Provider is valid @@ -297,6 +308,12 @@ impl OrchestratorConfig { "`generation` requires `tgis` or `nlp` provider".into(), )); } + // Hostname is valid + if !is_valid_hostname(&generation.service.hostname) { + return Err(Error::InvalidHostname( + "`generation` has an invalid hostname".into(), + )); + } } // Chat generation config is valid @@ -308,11 +325,9 @@ impl OrchestratorConfig { )); } // Hostname is valid - if !is_base_url(&chat_generation.service.hostname) { + if !is_valid_hostname(&chat_generation.service.hostname) { return Err(Error::InvalidHostname( - "`chat_generation` has an invalid hostname; \ - must be a base url, e.g. `https://service.route.example.com" - .into(), + "`chat_generation` has an invalid hostname".into(), )); } } diff --git a/tests/test.config.yaml b/tests/test.config.yaml index 773b85b7..91749356 100644 --- a/tests/test.config.yaml +++ b/tests/test.config.yaml @@ -13,7 +13,7 @@ detectors: test_detector: type: text_contents service: - hostname: https://localhost + hostname: localhost port: 8000 chunker_id: test_chunker default_threshold: 0.5 From 3b9d4c33fa30d1a68b4210fed1079cfa8f9e157b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 2 Oct 2024 19:45:06 -0300 Subject: [PATCH 03/50] Add detector type ADR Signed-off-by: Mateus Devino --- docs/architecture/adrs/006-detector-type.md | 45 +++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 docs/architecture/adrs/006-detector-type.md diff --git a/docs/architecture/adrs/006-detector-type.md b/docs/architecture/adrs/006-detector-type.md new file mode 100644 index 00000000..cb3a38da --- /dev/null +++ b/docs/architecture/adrs/006-detector-type.md @@ -0,0 +1,45 @@ +# ADR 006: Detector Type + +This ADR documents the decision of adding the `type` parameter for detectors in the orchestrator config. + +## Motivation + +The guardrails orchestrator interfaces with different types of detectors. +Detectors of a given are type are compatible with only a subset of orchestrator endpoints. +In order to reduce changes of misconfiguration, we need a way to map detectors to be used only with compatible endpoints. + + +## Decision + +We decided to add the `type` parameter to the detectors configuration. +Possible values are `text_contents`, `text_context_chat`, `text_generation` and `text_context_doc`. +Below is an example of detector configuration. + +```yaml +detectors: + my_detector: + type: text_contents # Options: text_contents, text_context_chat, text_context_doc, text_generation + service: + hostname: my-detector.com + port: 8080 + tls: my_certs + chunker_id: my_chunker + default_threshold: 0.5 +``` + +## Consequences + +1. Reduced misconfiguration risk. +2. Future logic can be implemented for detectors of a particular type. +3. `hostname` no longer needs the full URL, but only the actual hostname. +4. If `tls` is provided, the `https` protocol is used. `http`, otherwise. +5. Not including `type` results in a configuration validation error on orchestrator startup. +6. Detector endpoints are automatically configured based on `type` as follows: + * `text_contents` -> `/api/v1/text/contents` + * `text_context_chat` -> `/api/v1/text/context/chat` + * `text_context_doc` -> `/api/v1/text/context/doc` + * `text_generation` -> `/api/v1/text/generation` + +## Status + +Accepted \ No newline at end of file From 94950577e0f5768e75453be2b6b93034cc705ff7 Mon Sep 17 00:00:00 2001 From: declark1 Date: Thu, 3 Oct 2024 11:44:31 -0700 Subject: [PATCH 04/50] Rebase and add header passthrough to detectors, update tests Signed-off-by: declark1 --- src/clients/detector/text_contents.rs | 4 +++- src/clients/detector/text_context_doc.rs | 6 ++++-- src/clients/detector/text_generation.rs | 4 +++- src/config.rs | 2 ++ src/orchestrator/unary.rs | 5 ++--- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index 13b85c25..50a8013d 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use hyper::StatusCode; +use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; @@ -24,6 +24,7 @@ impl TextContentsDetectorClient { &self, model_id: &str, request: ContentAnalysisRequest, + headers: HeaderMap, ) -> Result>, Error> { let url = self .client @@ -33,6 +34,7 @@ impl TextContentsDetectorClient { let response = self .client .post(url) + .headers(headers) .header(DETECTOR_ID_HEADER_NAME, model_id) .json(&request) .send() diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index 3858c4cc..aca518f7 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use hyper::StatusCode; +use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; @@ -21,10 +21,11 @@ impl TextContextDocDetectorClient { Self { client } } - pub async fn text_context_docs( + pub async fn text_context_doc( &self, model_id: &str, request: ContextDocsDetectionRequest, + headers: HeaderMap, ) -> Result, Error> { let url = self .client @@ -34,6 +35,7 @@ impl TextContextDocDetectorClient { let response = self .client .post(url) + .headers(headers) .header(DETECTOR_ID_HEADER_NAME, model_id) .json(&request) .send() diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 7b55893d..7fcfe987 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use hyper::StatusCode; +use hyper::{HeaderMap, StatusCode}; use serde::Serialize; use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; @@ -25,6 +25,7 @@ impl TextGenerationDetectorClient { &self, model_id: &str, request: GenerationDetectionRequest, + headers: HeaderMap, ) -> Result, Error> { let url = self .client @@ -34,6 +35,7 @@ impl TextGenerationDetectorClient { let response = self .client .post(url) + .headers(headers) .header(DETECTOR_ID_HEADER_NAME, model_id) .json(&request) .send() diff --git a/src/config.rs b/src/config.rs index c997fd51..ed002804 100644 --- a/src/config.rs +++ b/src/config.rs @@ -662,6 +662,7 @@ chunkers: port: 9000 detectors: hap: + type: text_contents service: hostname: localhost port: 9000 @@ -689,6 +690,7 @@ chunkers: port: 9000 detectors: hap: + type: text_contents service: hostname: localhost port: 9000 diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 149695e5..a16b0e93 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -737,7 +737,7 @@ pub async fn detect_for_context( .get_as::(&detector_id) .unwrap(); let response = client - .text_context_docs(&detector_id, request, headers) + .text_context_doc(&detector_id, request, headers) .await .map(|results| { results @@ -908,8 +908,7 @@ mod tests { }; // Construct a behavior for the mock object - let headers = HeaderMap::new(); - faux::when!(tgis_client.generate(expected_generate_req_args, headers)) + faux::when!(tgis_client.generate(expected_generate_req_args, HeaderMap::new())) .once() // TODO: Add with_args .then_return(Ok(client_generation_response)); From c27e37a9fbb6bad65e71cc3a10a5197ff6284a49 Mon Sep 17 00:00:00 2001 From: declark1 Date: Thu, 3 Oct 2024 13:05:54 -0700 Subject: [PATCH 05/50] Apply health check related tweaks Signed-off-by: declark1 --- Cargo.lock | 11 +++ Cargo.toml | 1 + src/clients.rs | 169 ++++++++++++++++------------------------- src/clients/chunker.rs | 8 +- src/clients/errors.rs | 42 ++++------ src/clients/http.rs | 35 +++------ src/clients/nlp.rs | 8 +- src/clients/tgis.rs | 8 +- src/health.rs | 156 +++++++++++++------------------------ src/models.rs | 4 +- src/orchestrator.rs | 16 ++-- src/server.rs | 4 +- 12 files changed, 178 insertions(+), 284 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f33e376d..a6b64404 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -604,6 +604,7 @@ dependencies = [ "faux", "futures", "ginepro", + "http-serde", "hyper", "hyper-util", "indexmap 2.5.0", @@ -903,6 +904,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-serde" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd" +dependencies = [ + "http 1.1.0", + "serde", +] + [[package]] name = "httparse" version = "1.9.5" diff --git a/Cargo.toml b/Cargo.toml index c1079f6b..272713c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ uuid = { version = "1.10.0", features = ["v4", "fast-rng"] } async-trait = "0.1.81" async-stream = "0.3.5" indexmap = { version = "2.5.0", features = ["serde"] } +http-serde = "2.1.1" [build-dependencies] tonic-build = "0.12.1" diff --git a/src/clients.rs b/src/clients.rs index 93da58a9..d85c6b6b 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -35,7 +35,7 @@ use crate::{ }; pub mod errors; -pub use errors::{ClientCode, Error}; +pub use errors::Error; pub mod http; pub use http::HttpClient; @@ -334,6 +334,7 @@ pub fn is_valid_hostname(hostname: &str) -> bool { #[cfg(test)] mod tests { + use errors::grpc_to_http_code; use hyper::{http, StatusCode}; use reqwest::Response; @@ -373,83 +374,69 @@ mod tests { // READY responses from HTTP 200 OK with or without reason let response = [ (StatusCode::OK, r#"{}"#), - (StatusCode::OK, r#"{ "health_status": "HEALTHY" }"#), + (StatusCode::OK, r#"{ "status": "HEALTHY" }"#), + (StatusCode::OK, r#"{ "status": "meaningless status" }"#), ( StatusCode::OK, - r#"{ "health_status": "meaningless status" }"#, - ), - ( - StatusCode::OK, - r#"{ "health_status": "HEALTHY", "reason": "needless reason" }"#, + r#"{ "status": "HEALTHY", "reason": "needless reason" }"#, ), ]; for (status, body) in response.iter() { let response = mock_http_response(*status, body).await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Healthy); - assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK)); + assert_eq!(result.status, HealthStatus::Healthy); + assert_eq!(result.code, StatusCode::OK); assert_eq!(result.reason, None); let serialized = serde_json::to_string(&result).unwrap(); - assert_eq!(serialized, r#""HEALTHY""#); + assert_eq!(serialized, r#"{"status":"HEALTHY"}"#); } // NOT_READY response from HTTP 200 OK without reason - let response = - mock_http_response(StatusCode::OK, r#"{ "health_status": "UNHEALTHY" }"#).await; + let response = mock_http_response(StatusCode::OK, r#"{ "status": "UNHEALTHY" }"#).await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unhealthy); - assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK)); + assert_eq!(result.status, HealthStatus::Unhealthy); + assert_eq!(result.code, StatusCode::OK); assert_eq!(result.reason, None); let serialized = serde_json::to_string(&result).unwrap(); - assert_eq!( - serialized, - r#"{"health_status":"UNHEALTHY","response_code":"HTTP 200 OK"}"# - ); + assert_eq!(serialized, r#"{"status":"UNHEALTHY"}"#); // UNKNOWN response from HTTP 200 OK without reason - let response = - mock_http_response(StatusCode::OK, r#"{ "health_status": "UNKNOWN" }"#).await; + let response = mock_http_response(StatusCode::OK, r#"{ "status": "UNKNOWN" }"#).await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unknown); - assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK)); + assert_eq!(result.status, HealthStatus::Unknown); + assert_eq!(result.code, StatusCode::OK); assert_eq!(result.reason, None); let serialized = serde_json::to_string(&result).unwrap(); - assert_eq!( - serialized, - r#"{"health_status":"UNKNOWN","response_code":"HTTP 200 OK"}"# - ); + assert_eq!(serialized, r#"{"status":"UNKNOWN"}"#); // NOT_READY response from HTTP 200 OK with reason let response = mock_http_response( StatusCode::OK, - r#"{ "health_status": "UNHEALTHY", "reason": "some reason" }"#, + r#"{"status": "UNHEALTHY", "reason": "some reason" }"#, ) .await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unhealthy); - assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK)); + assert_eq!(result.status, HealthStatus::Unhealthy); + assert_eq!(result.code, StatusCode::OK); assert_eq!(result.reason, Some("some reason".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, - r#"{"health_status":"UNHEALTHY","response_code":"HTTP 200 OK","reason":"some reason"}"# + r#"{"status":"UNHEALTHY","reason":"some reason"}"# ); // UNKNOWN response from HTTP 200 OK with reason let response = mock_http_response( StatusCode::OK, - r#"{ "health_status": "UNKNOWN", "reason": "some reason" }"#, + r#"{ "status": "UNKNOWN", "reason": "some reason" }"#, ) .await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unknown); - assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK)); + assert_eq!(result.status, HealthStatus::Unknown); + assert_eq!(result.code, StatusCode::OK); assert_eq!(result.reason, Some("some reason".to_string())); let serialized = serde_json::to_string(&result).unwrap(); - assert_eq!( - serialized, - r#"{"health_status":"UNKNOWN","response_code":"HTTP 200 OK","reason":"some reason"}"# - ); + assert_eq!(serialized, r#"{"status":"UNKNOWN","reason":"some reason"}"#); // NOT_READY response from HTTP 503 SERVICE UNAVAILABLE with reason let response = mock_http_response( @@ -458,16 +445,13 @@ mod tests { ) .await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unhealthy); - assert_eq!( - result.response_code, - ClientCode::Http(StatusCode::SERVICE_UNAVAILABLE) - ); - assert_eq!(result.reason, Some(r#"HTTP status server error (503 Service Unavailable) for url (http://no.url.provided.local/): { "message": "some error message" }"#.to_string())); + assert_eq!(result.status, HealthStatus::Unhealthy); + assert_eq!(result.code, StatusCode::SERVICE_UNAVAILABLE); + assert_eq!(result.reason, Some("Service Unavailable".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, - r#"{"health_status":"UNHEALTHY","response_code":"HTTP 503 Service Unavailable","reason":"HTTP status server error (503 Service Unavailable) for url (http://no.url.provided.local/): { \"message\": \"some error message\" }"}"# + r#"{"status":"UNHEALTHY","code":503,"reason":"Service Unavailable"}"# ); // UNKNOWN response from HTTP 404 NOT FOUND with reason @@ -477,46 +461,37 @@ mod tests { ) .await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unknown); - assert_eq!( - result.response_code, - ClientCode::Http(StatusCode::NOT_FOUND) - ); - assert_eq!(result.reason, Some(r#"HTTP status client error (404 Not Found) for url (http://no.url.provided.local/): { "message": "service not found" }"#.to_string())); + assert_eq!(result.status, HealthStatus::Unknown); + assert_eq!(result.code, StatusCode::NOT_FOUND); + assert_eq!(result.reason, Some("Not Found".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, - r#"{"health_status":"UNKNOWN","response_code":"HTTP 404 Not Found","reason":"HTTP status client error (404 Not Found) for url (http://no.url.provided.local/): { \"message\": \"service not found\" }"}"# + r#"{"status":"UNKNOWN","code":404,"reason":"Not Found"}"# ); // NOT_READY response from HTTP 500 INTERNAL SERVER ERROR without reason let response = mock_http_response(StatusCode::INTERNAL_SERVER_ERROR, r#""#).await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unhealthy); - assert_eq!( - result.response_code, - ClientCode::Http(StatusCode::INTERNAL_SERVER_ERROR) - ); - assert_eq!(result.reason, Some("HTTP status server error (500 Internal Server Error) for url (http://no.url.provided.local/)".to_string())); + assert_eq!(result.status, HealthStatus::Unhealthy); + assert_eq!(result.code, StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(result.reason, Some("Internal Server Error".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, - r#"{"health_status":"UNHEALTHY","response_code":"HTTP 500 Internal Server Error","reason":"HTTP status server error (500 Internal Server Error) for url (http://no.url.provided.local/)"}"# + r#"{"status":"UNHEALTHY","code":500,"reason":"Internal Server Error"}"# ); // UNKNOWN response from HTTP 400 BAD REQUEST without reason let response = mock_http_response(StatusCode::BAD_REQUEST, r#""#).await; let result = HttpClient::http_response_to_health_check_result(response).await; - assert_eq!(result.health_status, HealthStatus::Unknown); - assert_eq!( - result.response_code, - ClientCode::Http(StatusCode::BAD_REQUEST) - ); - assert_eq!(result.reason, Some("HTTP status client error (400 Bad Request) for url (http://no.url.provided.local/)".to_string())); + assert_eq!(result.status, HealthStatus::Unknown); + assert_eq!(result.code, StatusCode::BAD_REQUEST); + assert_eq!(result.reason, Some("Bad Request".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, - r#"{"health_status":"UNKNOWN","response_code":"HTTP 400 Bad Request","reason":"HTTP status client error (400 Bad Request) for url (http://no.url.provided.local/)"}"# + r#"{"status":"UNKNOWN","code":400,"reason":"Bad Request"}"# ); } @@ -525,59 +500,47 @@ mod tests { // READY responses from gRPC 0 OK from serving status 1 SERVING let response = mock_grpc_response(Some(ServingStatus::Serving as i32), None).await; let result = HealthCheckResult::from(response); - assert_eq!(result.health_status, HealthStatus::Healthy); - assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok)); + assert_eq!(result.status, HealthStatus::Healthy); + assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok)); assert_eq!(result.reason, None); let serialized = serde_json::to_string(&result).unwrap(); - assert_eq!(serialized, r#""HEALTHY""#); + assert_eq!(serialized, r#"{"status":"HEALTHY"}"#); // NOT_READY response from gRPC 0 OK form serving status 2 NOT_SERVING let response = mock_grpc_response(Some(ServingStatus::NotServing as i32), None).await; let result = HealthCheckResult::from(response); - assert_eq!(result.health_status, HealthStatus::Unhealthy); - assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok)); - assert_eq!( - result.reason, - Some("from gRPC health check serving status: NOT_SERVING".to_string()) - ); + assert_eq!(result.status, HealthStatus::Unhealthy); + assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok)); + assert_eq!(result.reason, Some("NOT_SERVING".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, - r#"{"health_status":"UNHEALTHY","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: NOT_SERVING"}"# + r#"{"status":"UNHEALTHY","reason":"NOT_SERVING"}"# ); // UNKNOWN response from gRPC 0 OK from serving status 0 UNKNOWN let response = mock_grpc_response(Some(ServingStatus::Unknown as i32), None).await; let result = HealthCheckResult::from(response); - assert_eq!(result.health_status, HealthStatus::Unknown); - assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok)); - assert_eq!( - result.reason, - Some("from gRPC health check serving status: UNKNOWN".to_string()) - ); + assert_eq!(result.status, HealthStatus::Unknown); + assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok)); + assert_eq!(result.reason, Some("UNKNOWN".to_string())); let serialized = serde_json::to_string(&result).unwrap(); - assert_eq!( - serialized, - r#"{"health_status":"UNKNOWN","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: UNKNOWN"}"# - ); + assert_eq!(serialized, r#"{"status":"UNKNOWN","reason":"UNKNOWN"}"#); // UNKNOWN response from gRPC 0 OK from serving status 3 SERVICE_UNKNOWN let response = mock_grpc_response(Some(ServingStatus::ServiceUnknown as i32), None).await; let result = HealthCheckResult::from(response); - assert_eq!(result.health_status, HealthStatus::Unknown); - assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok)); - assert_eq!( - result.reason, - Some("from gRPC health check serving status: SERVICE_UNKNOWN".to_string()) - ); + assert_eq!(result.status, HealthStatus::Unknown); + assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok)); + assert_eq!(result.reason, Some("SERVICE_UNKNOWN".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, - r#"{"health_status":"UNKNOWN","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: SERVICE_UNKNOWN"}"# + r#"{"status":"UNKNOWN","reason":"SERVICE_UNKNOWN"}"# ); // UNKNOWN response from other gRPC error codes (covering main ones) - let response_codes = [ + let codes = [ tonic::Code::InvalidArgument, tonic::Code::Internal, tonic::Code::NotFound, @@ -586,23 +549,21 @@ mod tests { tonic::Code::PermissionDenied, tonic::Code::Unavailable, ]; - for code in response_codes.iter() { - let status = tonic::Status::new(*code, "some error message"); + for code in codes.into_iter() { + let status = tonic::Status::new(code, "some error message"); + let code = grpc_to_http_code(code); let response = mock_grpc_response(None, Some(status.clone())).await; let result = HealthCheckResult::from(response); - assert_eq!(result.health_status, HealthStatus::Unknown); - assert_eq!(result.response_code, ClientCode::Grpc(*code)); - assert_eq!( - result.reason, - Some(format!("gRPC health check failed: {}", status.clone())) - ); + assert_eq!(result.status, HealthStatus::Unknown); + assert_eq!(result.code, code); + assert_eq!(result.reason, Some("some error message".to_string())); let serialized = serde_json::to_string(&result).unwrap(); assert_eq!( serialized, format!( - r#"{{"health_status":"UNKNOWN","response_code":"gRPC {:?} {}","reason":"gRPC health check failed: status: {:?}, message: \"some error message\", details: [], metadata: MetadataMap {{ headers: {{}} }}"}}"#, - code, code, code - ) + r#"{{"status":"UNKNOWN","code":{},"reason":"some error message"}}"#, + code.as_u16() + ), ); } } diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index e1e097a3..e0113b97 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -25,7 +25,7 @@ use tokio_stream::wrappers::ReceiverStream; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::info; -use super::{BoxStream, Client, ClientCode, Error}; +use super::{errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ health::{HealthCheckResult, HealthStatus}, pb::{ @@ -134,14 +134,14 @@ impl Client for ChunkerClient { } Err(status) => status.code(), }; - let health_status = if matches!(code, Code::Ok) { + let status = if matches!(code, Code::Ok) { HealthStatus::Healthy } else { HealthStatus::Unhealthy }; HealthCheckResult { - health_status, - response_code: ClientCode::Grpc(code), + status, + code: grpc_to_http_code(code), reason: None, } } diff --git a/src/clients/errors.rs b/src/clients/errors.rs index 306d0e78..e8f638a8 100644 --- a/src/clients/errors.rs +++ b/src/clients/errors.rs @@ -54,38 +54,26 @@ impl From for Error { impl From for Error { fn from(value: tonic::Status) -> Self { - use tonic::Code::*; - // Return equivalent http status code for grpc status code - let code = match value.code() { - InvalidArgument => StatusCode::BAD_REQUEST, - Internal => StatusCode::INTERNAL_SERVER_ERROR, - NotFound => StatusCode::NOT_FOUND, - DeadlineExceeded => StatusCode::REQUEST_TIMEOUT, - Unimplemented => StatusCode::NOT_IMPLEMENTED, - Unauthenticated => StatusCode::UNAUTHORIZED, - PermissionDenied => StatusCode::FORBIDDEN, - Unavailable => StatusCode::SERVICE_UNAVAILABLE, - Ok => StatusCode::OK, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; Self::Grpc { - code, + code: grpc_to_http_code(value.code()), message: value.message().to_string(), } } } -#[derive(Debug, Clone, PartialEq)] -pub enum ClientCode { - Http(StatusCode), - Grpc(tonic::Code), -} - -impl std::fmt::Display for ClientCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ClientCode::Http(code) => write!(f, "HTTP {}", code), - ClientCode::Grpc(code) => write!(f, "gRPC {:?} {}", code, code), - } +/// Returns equivalent http status code for grpc status code +pub fn grpc_to_http_code(value: tonic::Code) -> StatusCode { + use tonic::Code::*; + match value { + InvalidArgument => StatusCode::BAD_REQUEST, + Internal => StatusCode::INTERNAL_SERVER_ERROR, + NotFound => StatusCode::NOT_FOUND, + DeadlineExceeded => StatusCode::REQUEST_TIMEOUT, + Unimplemented => StatusCode::NOT_IMPLEMENTED, + Unauthenticated => StatusCode::UNAUTHORIZED, + PermissionDenied => StatusCode::FORBIDDEN, + Unavailable => StatusCode::SERVICE_UNAVAILABLE, + Ok => StatusCode::OK, + _ => StatusCode::INTERNAL_SERVER_ERROR, } } diff --git a/src/clients/http.rs b/src/clients/http.rs index f6d1180c..fdd6aef6 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -3,7 +3,6 @@ use reqwest::Response; use tracing::error; use url::Url; -use super::ClientCode; use crate::health::{HealthCheckResult, HealthStatus, OptionalHealthCheckResponseBody}; #[derive(Clone)] @@ -37,9 +36,9 @@ impl HttpClient { if let Ok(body) = response.json::().await { // If the service provided a body, we only anticipate a minimal health status and optional reason. HealthCheckResult { - health_status: body.health_status.clone(), - response_code: ClientCode::Http(StatusCode::OK), - reason: match body.health_status { + status: body.status.clone(), + code: StatusCode::OK, + reason: match body.status { HealthStatus::Healthy => None, _ => body.reason, }, @@ -47,8 +46,8 @@ impl HttpClient { } else { // If the service did not provide a body, we assume it is healthy. HealthCheckResult { - health_status: HealthStatus::Healthy, - response_code: ClientCode::Http(StatusCode::OK), + status: HealthStatus::Healthy, + code: StatusCode::OK, reason: None, } } @@ -59,7 +58,7 @@ impl HttpClient { // Regardless we can't be certain, so the reason is also provided. // TODO: We will likely circle back to re-evaluate this logic in the future // when we know more about how the client health results will be used. - health_status: if response.status().as_u16() >= 500 + status: if response.status().as_u16() >= 500 && response.status().as_u16() < 600 { HealthStatus::Unhealthy @@ -74,30 +73,16 @@ impl HttpClient { ); HealthStatus::Unknown }, - response_code: ClientCode::Http(response.status()), - reason: Some(format!( - "{}{}", - response.error_for_status_ref().unwrap_err(), - response - .text() - .await - .map(|s| if s.is_empty() { - "".to_string() - } else { - format!(": {}", s) - }) - .unwrap_or("".to_string()) - )), + code: response.status(), + reason: response.status().canonical_reason().map(|v| v.to_string()), } } } Err(e) => { error!("error checking health: {}", e); HealthCheckResult { - health_status: HealthStatus::Unknown, - response_code: ClientCode::Http( - e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), - ), + status: HealthStatus::Unknown, + code: e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), reason: Some(e.to_string()), } } diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 7987b3dc..c40f9490 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -21,7 +21,7 @@ use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::{metadata::MetadataMap, Code, Request}; -use super::{BoxStream, Client, ClientCode, Error}; +use super::{errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ health::{HealthCheckResult, HealthStatus}, pb::{ @@ -137,14 +137,14 @@ impl Client for NlpClient { } Err(status) => status.code(), }; - let health_status = if matches!(code, Code::Ok) { + let status = if matches!(code, Code::Ok) { HealthStatus::Healthy } else { HealthStatus::Unhealthy }; HealthCheckResult { - health_status, - response_code: ClientCode::Grpc(code), + status, + code: grpc_to_http_code(code), reason: None, } } diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index 040e61e9..1a779b09 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -20,7 +20,7 @@ use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::Code; -use super::{BoxStream, Client, ClientCode, Error}; +use super::{errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ health::{HealthCheckResult, HealthStatus}, pb::fmaas::{ @@ -102,14 +102,14 @@ impl Client for TgisClient { } Err(status) => status.code(), }; - let health_status = if matches!(code, Code::Ok) { + let status = if matches!(code, Code::Ok) { HealthStatus::Healthy } else { HealthStatus::Unhealthy }; HealthCheckResult { - health_status, - response_code: ClientCode::Grpc(code), + status, + code: grpc_to_http_code(code), reason: None, } } diff --git a/src/health.rs b/src/health.rs index caae61aa..8ee56dac 100644 --- a/src/health.rs +++ b/src/health.rs @@ -1,11 +1,12 @@ use std::{collections::HashMap, fmt::Display}; use axum::http::StatusCode; -use serde::{ser::SerializeStruct, Deserialize, Serialize}; -use tonic::Code; -use tracing::{error, warn}; +use serde::{Deserialize, Serialize}; -use crate::{clients::ClientCode, pb::grpc::health::v1::HealthCheckResponse}; +use crate::{ + clients::errors::grpc_to_http_code, + pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse}, +}; /// Health status determined for or returned by a client service. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -31,11 +32,10 @@ impl Display for HealthStatus { impl From for HealthStatus { fn from(value: HealthCheckResponse) -> Self { - // NOTE: gRPC Health v1 status codes: 0 = UNKNOWN, 1 = SERVING, 2 = NOT_SERVING, 3 = SERVICE_UNKNOWN - match value.status { - 1 => Self::Healthy, - 2 => Self::Unhealthy, - _ => Self::Unknown, + match value.status() { + ServingStatus::Serving => Self::Healthy, + ServingStatus::NotServing => Self::Unhealthy, + ServingStatus::Unknown | ServingStatus::ServiceUnknown => Self::Unknown, } } } @@ -43,38 +43,19 @@ impl From for HealthStatus { impl From for HealthStatus { fn from(code: StatusCode) -> Self { match code.as_u16() { - 200 => Self::Healthy, - 201..=299 => { - warn!( - "Unexpected HTTP successful health check response status code: {}", - code - ); - Self::Healthy - } - 503 => Self::Unhealthy, - 500..=502 | 504..=599 => { - warn!( - "Unexpected HTTP server error health check response status code: {}", - code - ); - Self::Unhealthy - } - _ => { - warn!( - "Unexpected HTTP client error health check response status code: {}", - code - ); - Self::Unknown - } + 200..=299 => Self::Healthy, + 500..=599 => Self::Unhealthy, + _ => Self::Unknown, } } } -/// Holds health check results for all clients. +/// A cache to hold the latest health check results for each client service. +/// Orchestrator has a reference-counted mutex-protected instance of this cache. #[derive(Debug, Clone, Default, Serialize)] -pub struct ClientHealth(HashMap); +pub struct HealthCheckCache(HashMap); -impl ClientHealth { +impl HealthCheckCache { pub fn new() -> Self { Self(HashMap::new()) } @@ -83,15 +64,16 @@ impl ClientHealth { Self(HashMap::with_capacity(capacity)) } + /// Returns `true` if all services are healthy or unknown. pub fn healthy(&self) -> bool { !self .0 .iter() - .any(|(_, value)| matches!(value.health_status, HealthStatus::Unhealthy)) + .any(|(_, value)| matches!(value.status, HealthStatus::Unhealthy)) } } -impl std::ops::Deref for ClientHealth { +impl std::ops::Deref for HealthCheckCache { type Target = HashMap; fn deref(&self) -> &Self::Target { @@ -99,82 +81,48 @@ impl std::ops::Deref for ClientHealth { } } -impl std::ops::DerefMut for ClientHealth { +impl std::ops::DerefMut for HealthCheckCache { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } -/// Result of a health check request. -#[derive(Debug, Clone)] -pub struct HealthCheckResult { - /// Overall health status of client service. - /// `HEALTHY`, `UNHEALTHY`, or `UNKNOWN`. - pub health_status: HealthStatus, - /// Response code of the latest health check request. - /// This should be omitted on serialization if the health check was successful (when the response is `HTTP 200 OK` or `gRPC 0 OK`). - pub response_code: ClientCode, - /// Optional reason for the health check result status being `UNHEALTHY` or `UNKNOWN`. - /// May be omitted overall if the health check was successful. - pub reason: Option, +impl Display for HealthCheckCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", serde_json::to_string_pretty(self).unwrap()) + } } -impl HealthCheckResult { - pub fn reason_from_health_check_response(response: &HealthCheckResponse) -> Option { - match response.status { - 0 => Some("from gRPC health check serving status: UNKNOWN".to_string()), - 1 => None, - 2 => Some("from gRPC health check serving status: NOT_SERVING".to_string()), - 3 => Some("from gRPC health check serving status: SERVICE_UNKNOWN".to_string()), - _ => { - error!( - "Unexpected gRPC health check serving status: {}", - response.status - ); - Some(format!( - "Unexpected gRPC health check serving status: {}", - response.status - )) - } +impl HealthCheckResponse { + pub fn reason(&self) -> Option { + let status = self.status(); + match status { + ServingStatus::Serving => None, + _ => Some(status.as_str_name().to_string()), } } } -impl Serialize for HealthCheckResult { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match self.health_status { - HealthStatus::Healthy => self.health_status.serialize(serializer), - _ => match &self.reason { - Some(reason) => { - let mut state = serializer.serialize_struct("HealthCheckResult", 3)?; - state.serialize_field("health_status", &self.health_status)?; - state.serialize_field("response_code", &self.response_code.to_string())?; - state.serialize_field("reason", reason)?; - state.end() - } - None => { - let mut state = serializer.serialize_struct("HealthCheckResult", 2)?; - state.serialize_field("health_status", &self.health_status)?; - state.serialize_field("response_code", &self.response_code.to_string())?; - state.end() - } - }, - } - } +/// Result of a health check request. +#[derive(Debug, Clone, Serialize)] +pub struct HealthCheckResult { + /// Overall health status of client service. + pub status: HealthStatus, + /// Response code of the latest health check request. + #[serde( + with = "http_serde::status_code", + skip_serializing_if = "StatusCode::is_success" + )] + pub code: StatusCode, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, } impl Display for HealthCheckResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.reason { - Some(reason) => write!( - f, - "{} ({})\n\t\t\t{}", - self.health_status, self.response_code, reason - ), - None => write!(f, "{} ({})", self.health_status, self.response_code), + Some(reason) => write!(f, "{} ({})\n\t\t\t{}", self.status, self.code, reason), + None => write!(f, "{} ({})", self.status, self.code), } } } @@ -185,15 +133,15 @@ impl From, tonic::Status>> for Healt Ok(response) => { let response = response.into_inner(); Self { - health_status: response.into(), - response_code: ClientCode::Grpc(Code::Ok), - reason: Self::reason_from_health_check_response(&response), + status: response.into(), + code: StatusCode::OK, + reason: response.reason(), } } Err(status) => Self { - health_status: HealthStatus::Unknown, - response_code: ClientCode::Grpc(status.code()), - reason: Some(format!("gRPC health check failed: {}", status)), + status: HealthStatus::Unknown, + code: grpc_to_http_code(status.code()), + reason: Some(status.message().to_string()), }, } } @@ -205,7 +153,7 @@ impl From, tonic::Status>> for Healt #[derive(Deserialize)] pub struct OptionalHealthCheckResponseBody { /// `HEALTHY`, `UNHEALTHY`, or `UNKNOWN`. Although `HEALTHY` is already implied without a body. - pub health_status: HealthStatus, + pub status: HealthStatus, /// Optional reason for the health check result status being `UNHEALTHY` or `UNKNOWN`. /// May be omitted overall if the health check was successful. #[serde(default)] diff --git a/src/models.rs b/src/models.rs index 869e1f05..ca819628 100644 --- a/src/models.rs +++ b/src/models.rs @@ -23,13 +23,13 @@ use serde::{Deserialize, Serialize}; use crate::{ clients::detector::{ContentAnalysisResponse, ContextType}, - health::ClientHealth, + health::HealthCheckCache, pb, }; #[derive(Clone, Debug, Serialize)] pub struct InfoResponse { - pub client_health: ClientHealth, + pub services: HealthCheckCache, } #[derive(Debug, Clone, Deserialize)] diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 500ac1da..52444f28 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -39,7 +39,7 @@ use crate::{ TgisClient, }, config::{DetectorType, GenerationProvider, OrchestratorConfig}, - health::ClientHealth, + health::HealthCheckCache, models::{ ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest, @@ -81,7 +81,7 @@ impl Context { #[cfg_attr(test, derive(Default))] pub struct Orchestrator { ctx: Arc, - client_health: Arc>, + client_health: Arc>, } impl Orchestrator { @@ -93,7 +93,7 @@ impl Orchestrator { let ctx = Arc::new(Context { config, clients }); let orchestrator = Self { ctx, - client_health: Arc::new(RwLock::new(ClientHealth::default())), + client_health: Arc::new(RwLock::new(HealthCheckCache::default())), }; debug!("running start up checks"); orchestrator.on_start_up(start_up_health_check).await?; @@ -114,25 +114,25 @@ impl Orchestrator { info!("Probing client health..."); let client_health = self.client_health(true).await; // Results of probe do not affect orchestrator start-up. - info!("Client health: {client_health:?}"); // TODO: re-impl Display + info!("Client health:\n{client_health}"); } Ok(()) } /// Returns client health state. - pub async fn client_health(&self, probe: bool) -> ClientHealth { + pub async fn client_health(&self, probe: bool) -> HealthCheckCache { let initialized = !self.client_health.read().await.is_empty(); if probe || !initialized { debug!("refreshing client health"); let now = Instant::now(); - let mut state = ClientHealth::with_capacity(self.ctx.clients.len()); + let mut health = HealthCheckCache::with_capacity(self.ctx.clients.len()); // TODO: perform health checks concurrently? for (key, client) in self.ctx.clients.iter() { let result = client.health().await; - state.insert(key.into(), result); + health.insert(key.into(), result); } let mut client_health = self.client_health.write().await; - *client_health = state; + *client_health = health; debug!( "refreshing client health completed in {:.2?}ms", now.elapsed().as_millis() diff --git a/src/server.rs b/src/server.rs index 6343d29a..0ad70a3e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -295,8 +295,8 @@ async fn info( State(state): State>, Query(params): Query, ) -> Result, Error> { - let client_health = state.orchestrator.client_health(params.probe).await; - Ok(Json(InfoResponse { client_health })) + let services = state.orchestrator.client_health(params.probe).await; + Ok(Json(InfoResponse { services })) } async fn classification_with_gen( From ad1a562bb14935cda4b10ef403f80f80f2a925d6 Mon Sep 17 00:00:00 2001 From: declark1 Date: Fri, 4 Oct 2024 11:46:28 -0700 Subject: [PATCH 06/50] Update openapi spec Signed-off-by: declark1 --- docs/api/orchestrator_openapi_0_1_0.yaml | 346 +++++++++++------------ src/orchestrator.rs | 4 +- 2 files changed, 166 insertions(+), 184 deletions(-) diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml index d7196d3e..9b375f74 100644 --- a/docs/api/orchestrator_openapi_0_1_0.yaml +++ b/docs/api/orchestrator_openapi_0_1_0.yaml @@ -3,12 +3,12 @@ info: title: FMS Orchestrator API version: 0.1.0 tags: - - name: Task - Text Generation, with detection - description: Detections on text generation model input and/or output - - name: Task - Detection - description: Standalone detections - - name: Task - Chat Completions, with detection - description: Detections on list of messages comprising a conversation and/or completions from a model + - name: Task - Text Generation, with detection + description: Detections on text generation model input and/or output + - name: Task - Detection + description: Standalone detections + - name: Task - Chat Completions, with detection + description: Detections on list of messages comprising a conversation and/or completions from a model paths: /health: get: @@ -17,7 +17,7 @@ paths: summary: Performs quick liveliness check of the orchestrator service operationId: health responses: - '200': + "200": description: Healthy content: application/json: @@ -42,18 +42,18 @@ paths: type: boolean default: false responses: - '200': + "200": description: Orchestrator successfully probed client health statuses content: application/json: schema: - $ref: '#/components/schemas/HealthProbeResponse' - '503': + $ref: "#/components/schemas/InfoResponse" + "503": description: Orchestrator failed to probe client health statuses content: application/json: schema: - $ref: '#/components/schemas/HealthProbeResponse' + $ref: "#/components/schemas/InfoResponse" /api/v1/task/classification-with-text-generation: post: tags: @@ -65,27 +65,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GuardrailsHttpRequest' + $ref: "#/components/schemas/GuardrailsHttpRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/ClassifiedGeneratedTextResult' - '404': + $ref: "#/components/schemas/ClassifiedGeneratedTextResult" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v1/task/server-streaming-classification-with-text-generation: post: tags: @@ -97,27 +97,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GuardrailsHttpRequest' + $ref: "#/components/schemas/GuardrailsHttpRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/ClassifiedGeneratedTextStreamResult' - '404': + $ref: "#/components/schemas/ClassifiedGeneratedTextStreamResult" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/generation-detection: post: tags: @@ -129,27 +129,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GenerationDetectionRequest' + $ref: "#/components/schemas/GenerationDetectionRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/GenerationDetectionResponse' - '404': + $ref: "#/components/schemas/GenerationDetectionResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/detection/content: post: @@ -162,27 +162,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/DetectionContentRequest' + $ref: "#/components/schemas/DetectionContentRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/DetectionContentResponse' - '404': + $ref: "#/components/schemas/DetectionContentResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/detection/chat: post: tags: @@ -194,27 +194,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/DetectionChatRequest' + $ref: "#/components/schemas/DetectionChatRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/DetectionChatResponse' - '404': + $ref: "#/components/schemas/DetectionChatResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/detection/context: post: @@ -227,27 +227,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/DetectionContextDocsRequest' + $ref: "#/components/schemas/DetectionContextDocsRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/DetectionContextDocsResponse' - '404': + $ref: "#/components/schemas/DetectionContextDocsResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/detection/generated: post: tags: @@ -259,122 +259,104 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GeneratedTextDetectionRequest' + $ref: "#/components/schemas/GeneratedTextDetectionRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/GeneratedTextDetectionResponse' - '404': + $ref: "#/components/schemas/GeneratedTextDetectionResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' - + $ref: "#/components/schemas/Error" + /api/v2/chat/completions-detection: post: - tags: - - Task - Chat Completions, with detection - operationId: >- - api_v2_chat_completions_detection_handler - summary: Creates a model response with detections for the given chat conversation. - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/GuardrailsCreateChatCompletionRequest" - responses: - '200': - description: Successful Response - content: - application/json: - schema: - $ref: "#/components/schemas/GuardrailsCreateChatCompletionResponse" - '404': - description: Resource Not Found - content: - application/json: - schema: - $ref: '#/components/schemas/Error' - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/Error' + tags: + - Task - Chat Completions, with detection + operationId: >- + api_v2_chat_completions_detection_handler + summary: Creates a model response with detections for the given chat conversation. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/GuardrailsCreateChatCompletionRequest" + responses: + "200": + description: Successful Response + content: + application/json: + schema: + $ref: "#/components/schemas/GuardrailsCreateChatCompletionResponse" + "404": + description: Resource Not Found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "422": + description: Validation Error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" components: schemas: HealthStatus: - type: string - enum: - - HEALTHY - - UNHEALTHY - - UNKNOWN - title: Health Status + type: string + enum: + - HEALTHY + - UNHEALTHY + - UNKNOWN + title: Health Status HealthCheckResult: - oneOf: - - properties: - health_status: - $ref: '#/components/schemas/HealthStatus' - response_code: - type: string - title: Response Code - example: "HTTP 200 OK" - reason: - type: string - title: Reason - example: "Service not found" - required: - - health_status - - response_code - - additionalProperties: - $ref: '#/components/schemas/HealthStatus' + properties: + status: + $ref: "#/components/schemas/HealthStatus" + code: + type: string + title: Response Code + example: 200 + reason: + type: string + title: Reason + example: "Not Found" + required: + - status type: object - title: Health Check Response - HealthProbeResponse: + title: Health Check Result + InfoResponse: properties: services: type: object title: Health status for each client service - properties: - generation: - type: object - title: Generation Services - items: - $ref: '#/components/schemas/HealthCheckResult' - detectors: - type: object - title: Detector Services - items: - $ref: '#/components/schemas/HealthCheckResult' - chunkers: - type: object - title: Chunker Services - items: - $ref: '#/components/schemas/HealthCheckResult' + items: + $ref: "#/components/schemas/HealthCheckResult" required: - services type: object - title: Health Probe Response + title: Info Response DetectionContentRequest: properties: detectors: type: object title: Detectors default: {} - example: + example: hap-v1-model-en: {} content: type: string @@ -388,7 +370,7 @@ components: detections: type: array items: - $ref: '#/components/schemas/DetectionContentResponseObject' + $ref: "#/components/schemas/DetectionContentResponseObject" additionalProperties: false required: ["detections"] type: object @@ -422,7 +404,7 @@ components: "detection_type": "HAP", "detection": "has_HAP", "detector_id": "hap-v1-model-en", - "score": 0.999 + "score": 0.999, } DetectionChatRequest: @@ -431,7 +413,7 @@ components: type: object title: Detectors default: {} - example: + example: chat-v1-model-en: {} messages: title: Chat Messages @@ -465,14 +447,14 @@ components: title: Detections on entire history of chat messages title: Chat Detection Response required: ["detections"] - + DetectionContextDocsRequest: properties: detectors: type: object title: Detectors default: {} - example: + example: context-v1-model-en: {} content: type: string @@ -500,7 +482,7 @@ components: detections: type: array items: - $ref: '#/components/schemas/DetectionContextDocsResponseObject' + $ref: "#/components/schemas/DetectionContextDocsResponseObject" required: ["detections"] title: Context Docs Detection Response DetectionContextDocsResponseObject: @@ -516,9 +498,9 @@ components: title: Score evidence: anyOf: - - items: - $ref: '#/components/schemas/EvidenceObj' - type: array + - items: + $ref: "#/components/schemas/EvidenceObj" + type: array title: Context Docs Detection Response Object GenerationDetectionRequest: @@ -533,11 +515,11 @@ components: type: object title: Detectors default: {} - example: + example: generation-detection-v1-model-en: {} text_gen_parameters: allOf: - - $ref: '#/components/schemas/GuardrailsTextGenerationParameters' + - $ref: "#/components/schemas/GuardrailsTextGenerationParameters" type: object required: ["model_id", "prompt", "detectors"] title: Generation-Detection Request @@ -566,7 +548,7 @@ components: title: Input token Count title: Generation Detection Response required: ["generated_text", "detections"] - + GeneratedTextDetectionRequest: properties: prompt: @@ -579,7 +561,7 @@ components: type: object title: Detectors default: {} - example: + example: generated-detection-v1-model-en: {} type: object required: ["generated_text", "prompt", "detectors"] @@ -589,7 +571,7 @@ components: detections: type: array items: - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject' + $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: ["detections"] title: Generated Text Detection Response GeneratedTextDetectionResponseObject: @@ -614,10 +596,10 @@ components: title: Generated Text token_classification_results: anyOf: - - $ref: '#/components/schemas/TextGenTokenClassificationResults' + - $ref: "#/components/schemas/TextGenTokenClassificationResults" finish_reason: anyOf: - - $ref: '#/components/schemas/FinishReason' + - $ref: "#/components/schemas/FinishReason" generated_token_count: anyOf: - type: integer @@ -633,19 +615,19 @@ components: warnings: anyOf: - items: - $ref: '#/components/schemas/InputWarning' + $ref: "#/components/schemas/InputWarning" type: array title: Warnings tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Tokens input_tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Input Tokens additionalProperties: false @@ -660,10 +642,10 @@ components: title: Generated Text token_classification_results: anyOf: - - $ref: '#/components/schemas/TextGenTokenClassificationResults' + - $ref: "#/components/schemas/TextGenTokenClassificationResults" finish_reason: anyOf: - - $ref: '#/components/schemas/FinishReason' + - $ref: "#/components/schemas/FinishReason" generated_token_count: anyOf: - type: integer @@ -679,19 +661,19 @@ components: warnings: anyOf: - items: - $ref: '#/components/schemas/InputWarning' + $ref: "#/components/schemas/InputWarning" type: array title: Warnings tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Tokens input_tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Input Tokens processed_index: @@ -706,18 +688,18 @@ components: type: object title: Classified Generated Text Stream Result TextGenTokenClassificationResults: - # By default open-api spec consider all fields as optional + # By default open-api spec consider all fields as optional properties: input: anyOf: - items: - $ref: '#/components/schemas/TokenClassificationResult' + $ref: "#/components/schemas/TokenClassificationResult" type: array title: Input output: anyOf: - items: - $ref: '#/components/schemas/TokenClassificationResult' + $ref: "#/components/schemas/TokenClassificationResult" type: array title: Output additionalProperties: false @@ -765,7 +747,7 @@ components: default: {} required: - detectors - + GuardrailsCreateChatCompletionResponse: title: Guardrails Chat Completion Response description: Guardrails chat completion response (adds detections on OpenAI chat completion) @@ -778,7 +760,7 @@ components: warnings: type: array items: - $ref: '#/components/schemas/Warning' + $ref: "#/components/schemas/Warning" required: - detections @@ -800,27 +782,27 @@ components: output: pii-v1: {} conversation-detector: {} - + ChatCompletionsDetections: title: Chat Completions Detections properties: input: type: array items: - $ref: '#/components/schemas/MessageDetections' + $ref: "#/components/schemas/MessageDetections" title: Detections on input to chat completions default: {} output: type: array items: - $ref: '#/components/schemas/ChoiceDetections' + $ref: "#/components/schemas/ChoiceDetections" title: Detections on output of chat completions default: {} default: {} example: input: - message_index: 0 - results: + results: - { "start": 0, "end": 80, @@ -828,7 +810,7 @@ components: "detection_type": "HAP", "detection": "has_HAP", "detector_id": "hap-v1-model-en", # Future addition - "score": 0.999 + "score": 0.999, } output: - choice_index: 0 @@ -841,15 +823,15 @@ components: "detection_type": "HAP", "detection": "has_HAP", "detector_id": "hap-v1-model-en", # Future addition - "score": 0.999 + "score": 0.999, } - { "detection_type": "string", "detection": "string", "detector_id": "relevance-v1-en", # Future addition - "score": 0 + "score": 0, } - + MessageDetections: title: Message Detections properties: @@ -863,9 +845,9 @@ components: type: array items: anyOf: - - $ref: '#/components/schemas/DetectionContentResponseObject' - - $ref: '#/components/schemas/DetectionContextDocsResponseObject' - - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject' + - $ref: "#/components/schemas/DetectionContentResponseObject" + - $ref: "#/components/schemas/DetectionContextDocsResponseObject" + - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: - message_index ChoiceDetections: @@ -881,9 +863,9 @@ components: type: array items: anyOf: - - $ref: '#/components/schemas/DetectionContentResponseObject' - - $ref: '#/components/schemas/DetectionContextDocsResponseObject' - - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject' + - $ref: "#/components/schemas/DetectionContentResponseObject" + - $ref: "#/components/schemas/DetectionContextDocsResponseObject" + - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: - choice_index @@ -948,7 +930,7 @@ components: evidence: anyOf: - items: - $ref: '#/components/schemas/Evidence' + $ref: "#/components/schemas/Evidence" type: array type: object required: @@ -1010,7 +992,7 @@ components: title: Inputs guardrail_config: allOf: - - $ref: '#/components/schemas/GuardrailsConfig' + - $ref: "#/components/schemas/GuardrailsConfig" default: input: masks: [] @@ -1019,7 +1001,7 @@ components: models: {} text_gen_parameters: allOf: - - $ref: '#/components/schemas/GuardrailsTextGenerationParameters' + - $ref: "#/components/schemas/GuardrailsTextGenerationParameters" type: object required: - model_id @@ -1059,7 +1041,7 @@ components: title: Max Time exponential_decay_length_penalty: allOf: - - $ref: '#/components/schemas/ExponentialDecayLengthPenalty' + - $ref: "#/components/schemas/ExponentialDecayLengthPenalty" stop_sequences: items: type: string @@ -1094,7 +1076,7 @@ components: properties: id: allOf: - - $ref: '#/components/schemas/InputWarningReason' + - $ref: "#/components/schemas/InputWarningReason" message: type: string title: Message diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 52444f28..48e2d2b2 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -123,7 +123,7 @@ impl Orchestrator { pub async fn client_health(&self, probe: bool) -> HealthCheckCache { let initialized = !self.client_health.read().await.is_empty(); if probe || !initialized { - debug!("refreshing client health"); + debug!("refreshing health cache"); let now = Instant::now(); let mut health = HealthCheckCache::with_capacity(self.ctx.clients.len()); // TODO: perform health checks concurrently? @@ -134,7 +134,7 @@ impl Orchestrator { let mut client_health = self.client_health.write().await; *client_health = health; debug!( - "refreshing client health completed in {:.2?}ms", + "refreshing health cache completed in {:.2?}ms", now.elapsed().as_millis() ); } From 9a615efff7e6128f191a9b0f7098052c72232925 Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Mon, 7 Oct 2024 12:56:21 -0700 Subject: [PATCH 07/50] Updates to align OpenAI Chat Completions with current spec and include OpenAI-specific items, drop Completions API. Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- Cargo.lock | 2 - Cargo.toml | 1 - src/clients/openai.rs | 481 +++++++++++++++++++++++------------------- 3 files changed, 262 insertions(+), 222 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a6b64404..60da0bbb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -607,7 +607,6 @@ dependencies = [ "http-serde", "hyper", "hyper-util", - "indexmap 2.5.0", "mio", "prost", "reqwest", @@ -1057,7 +1056,6 @@ checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", "hashbrown 0.14.5", - "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 272713c4..ef13667e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,6 @@ url = "2.5.2" uuid = { version = "1.10.0", features = ["v4", "fast-rng"] } async-trait = "0.1.81" async-stream = "0.3.5" -indexmap = { version = "2.5.0", features = ["serde"] } http-serde = "2.1.1" [build-dependencies] diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 19bbba2d..8dd1f03c 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use async_trait::async_trait; use hyper::StatusCode; -use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use super::{Client, Error, HttpClient}; @@ -34,21 +33,6 @@ impl OpenAiClient { }), } } - - pub async fn completions( - &self, - request: CompletionRequest, - ) -> Result { - let url = self.client.base_url().join("/v1/completions").unwrap(); - let response = self.client.post(url).json(&request).send().await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - _ => Err(Error::Http { - code: response.status(), - message: "".into(), // TODO - }), - } - } } #[cfg_attr(test, faux::methods)] @@ -63,124 +47,277 @@ impl Client for OpenAiClient { } } -/// Usage statistics for a completion. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Usage { - /// Number of tokens in the generated completion. - pub completion_tokens: u32, - /// Number of tokens in the prompt. - pub prompt_tokens: u32, - /// Total number of tokens used in the request (prompt + completion). - pub total_tokens: u32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum StopTokens { - Array(Vec), - String(String), -} - -// Chat completions API types - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionRequest { - /// ID of the model to use. - pub model: String, /// A list of messages comprising the conversation so far. pub messages: Vec, - #[serde(default)] + /// ID of the model to use. + pub model: String, + /// Whether or not to store the output of this chat completion request. + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + /// Developer-defined tags and values. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub frequency_penalty: Option, /// Modify the likelihood of specified tokens appearing in the completion. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub logit_bias: Option>, /// Whether to return log probabilities of the output tokens or not. /// If true, returns the log probabilities of each output token returned in the content of message. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, /// An integer between 0 and 20 specifying the number of most likely tokens to return /// at each token position, each with an associated log probability. /// logprobs must be set to true if this parameter is used. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub top_logprobs: Option, - /// The maximum number of tokens that can be generated in the chat completion. - #[serde(default)] + /// The maximum number of tokens that can be generated in the chat completion. (DEPRECATED) + #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, + /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, /// How many chat completion choices to generate for each input message. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub n: Option, /// Positive values penalize new tokens based on whether they appear in the text so far, /// increasing the model's likelihood to talk about new topics. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub presence_penalty: Option, - //#[serde(default)] - //pub response_format: Option, + /// An object specifying the format that the model must output. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, /// If specified, our system will make a best effort to sample deterministically, /// such that repeated requests with the same seed and parameters should return the same result. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub seed: Option, + /// Specifies the latency tier to use for processing the request. + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option, /// Up to 4 sequences where the API will stop generating further tokens. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option, /// If set, partial message deltas will be sent, like in ChatGPT. /// Tokens will be sent as data-only server-sent events as they become available, /// with the stream terminated by a data: [DONE] message. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub stream: Option, + /// Options for streaming response. Only set this when you set stream: true. + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, /// What sampling temperature to use, between 0 and 2. /// Higher values like 0.8 will make the output more random, /// while lower values like 0.2 will make it more focused and deterministic. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, /// An alternative to sampling with temperature, called nucleus sampling, /// where the model considers the results of the tokens with top_p probability mass. /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, + /// A list of tools the model may call. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, + /// Controls which (if any) tool is called by the model. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + /// Whether to enable parallel function calling during tool use. + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + /// A unique identifier representing your end-user. + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, // Additional vllm params - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub best_of: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub use_beam_search: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub top_k: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub min_p: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub repetition_penalty: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub length_penalty: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub early_stopping: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub ignore_eos: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub min_tokens: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub stop_token_ids: Option>, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub skip_special_tokens: Option, - #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] pub spaces_between_special_tokens: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseFormat { + /// The type of response format being defined. + #[serde(rename = "type")] + pub r#type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonSchema { + /// The name of the response format. + pub name: String, + /// A description of what the response format is for, used by the model to determine how to respond in the format. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The schema for the response format, described as a JSON Schema object. + #[serde(skip_serializing_if = "Option::is_none")] + pub schema: Option, + /// Whether to enable strict schema adherence when generating the output. + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + /// The type of the tool. + #[serde(rename = "type")] + pub r#type: String, + pub function: ToolFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolFunction { + /// The name of the function to be called. + pub name: String, + /// A description of what the function does, used by the model to choose when and how to call the function. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The parameters the functions accepts, described as a JSON Schema object. + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, + /// Whether to enable strict schema adherence when generating the function call. + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + /// `none` means the model will not call any tool and instead generates a message. + /// `auto` means the model can pick between generating a message or calling one or more tools. + /// `required` means the model must call one or more tools. + String, + /// Specifies a tool the model should use. Use to force the model to call a specific function. + Object(ToolChoiceObject), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolChoiceObject { + /// The type of the tool. + #[serde(rename = "type")] + pub r#type: String, + pub function: Function, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamOptions { + /// If set, an additional chunk will be streamed before the data: [DONE] message. + /// The usage field on this chunk shows the token usage statistics for the entire + /// request, and the choices field will always be an empty array. All other chunks + /// will also include a usage field, but with a null value. + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonSchemaObject { + pub id: String, + pub schema: String, + pub title: String, + pub description: Option, + #[serde(rename = "type")] + pub r#type: String, + pub properties: Option>, + pub required: Option>, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { + /// The role of the messages author. pub role: String, - pub content: String, + /// The contents of the message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// An optional name for the participant. #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, + /// The refusal message by the assistant. (assistant message only) + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option, + /// The tool calls generated by the model, such as function calls. (assistant message only) + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + /// Tool call that this message is responding to. (tool message only) + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, } -impl Message { - pub fn new(role: &str, content: &str, name: Option<&str>) -> Self { - Self { - role: role.into(), - content: content.into(), - name: name.map(|s| s.into()), - } - } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Content { + /// The text contents of the message. + String(String), + /// Array of content parts. + Array(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContentPart { + /// The type of the content part. + #[serde(rename = "type")] + pub r#type: String, + /// Text content + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + /// Image content + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option, + /// The refusal message generated by the model. (assistant message only) + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageUrl { + /// Either a URL of the image or the base64 encoded image data. + pub url: String, + /// Specifies the detail level of the image. + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + /// The ID of the tool call. + pub id: String, + /// The type of the tool. + #[serde(rename = "type")] + pub r#type: String, + /// The function that the model called. + pub function: Function, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Function { + /// The name of the function to call. + pub name: String, + /// The arguments to call the function with, as generated by the model in JSON format. + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, } /// Represents a chat completion response returned by model, based on the provided input. @@ -188,17 +325,20 @@ impl Message { pub struct ChatCompletionResponse { /// A unique identifier for the chat completion. pub id: String, - /// The object type, which is always `chat.completion`. - pub object: String, + /// A list of chat completion choices. Can be more than one if n is greater than 1. + pub choices: Vec, /// The Unix timestamp (in seconds) of when the chat completion was created. pub created: i64, /// The model used for the chat completion. pub model: String, + /// The service tier used for processing the request. + /// This field is only included if the `service_tier` parameter is specified in the request. + pub service_tier: Option, /// This fingerprint represents the backend configuration that the model runs with. - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, - /// A list of chat completion choices. Can be more than one if n is greater than 1. - pub choices: Vec, + #[serde(default)] + pub system_fingerprint: String, + /// The object type, which is always `chat.completion`. + pub object: String, /// Usage statistics for the completion request. pub usage: Usage, } @@ -206,31 +346,35 @@ pub struct ChatCompletionResponse { /// A chat completion choice. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionChoice { + /// The reason the model stopped generating tokens. + pub finish_reason: String, /// The index of the choice in the list of choices. pub index: usize, /// A chat completion message generated by the model. pub message: ChatCompletionMessage, /// Log probability information for the choice. pub logprobs: Option, - /// The reason the model stopped generating tokens. - pub finish_reason: String, } /// A chat completion message generated by the model. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionMessage { /// The contents of the message. - #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, + /// The refusal message generated by the model. + pub refusal: Option, + #[serde(default)] + pub tool_calls: Vec, /// The role of the author of this message. - #[serde(skip_serializing_if = "Option::is_none")] - pub role: Option, + pub role: String, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ChatCompletionLogprobs { - #[serde(skip_serializing_if = "Vec::is_empty")] - pub content: Vec, + /// A list of message content tokens with log probability information. + pub content: Option>, + /// A list of message refusal tokens with log probability information. + pub refusal: Option>, } /// Log probability information for a choice. @@ -240,6 +384,7 @@ pub struct ChatCompletionLogprob { pub token: String, /// The log probability of this token. pub logprob: f32, + pub bytes: Option>, /// List of the most likely tokens and their log probability, at this token position. pub top_logprobs: Option>, } @@ -263,8 +408,13 @@ pub struct ChatCompletionChunk { pub created: i64, /// The model to generate the completion. pub model: String, + /// The service tier used for processing the request. + /// This field is only included if the service_tier parameter is specified in the request. + pub service_tier: Option, + /// This fingerprint represents the backend configuration that the model runs with. + pub system_fingerprint: String, /// The object type, which is always `chat.completion.chunk`. - pub object: &'static str, + pub object: String, #[serde(skip_serializing_if = "Option::is_none")] pub usage: Option, } @@ -273,151 +423,44 @@ pub struct ChatCompletionChunk { pub struct ChatCompletionChunkChoice { /// A chat completion delta generated by streamed model responses. pub delta: ChatCompletionMessage, - /// The index of the choice in the list of choices. - pub index: u32, /// Log probability information for the choice. pub logprobs: Option, /// The reason the model stopped generating tokens. pub finish_reason: Option, + /// The index of the choice in the list of choices. + pub index: u32, } -// Completions API types - +/// Usage statistics for a completion. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CompletionRequest { - /// ID of the model to use. - pub model: String, - /// The prompt to generate completions for. - /// NOTE: Only supporting a single prompt for now. OpenAI supports a single string, - /// array of strings, array of tokens, or an array of token arrays. - pub prompt: String, - /// Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token). - /// Results cannot be streamed. When used with n, best_of controls the number of candidate completions and n specifies - /// how many to return – best_of must be greater than n. - #[serde(default)] - pub best_of: Option, - /// Echo back the prompt in addition to the completion. - #[serde(default)] - pub echo: Option, - /// Positive values penalize new tokens based on their existing frequency in the text so far, - /// decreasing the model's likelihood to repeat the same line verbatim. - #[serde(default)] - pub frequency_penalty: Option, - /// Modify the likelihood of specified tokens appearing in the completion. - #[serde(default)] - pub logit_bias: Option>, - /// Include the log probabilities on the logprobs most likely output tokens, as well the chosen tokens. - #[serde(default)] - pub logprobs: Option, - /// The maximum number of tokens that can be generated in the completion. - #[serde(default)] - pub max_tokens: Option, - /// How many completions to generate for each prompt. - #[serde(default)] - pub n: Option, - /// Positive values penalize new tokens based on whether they appear in the text so far, - /// increasing the model's likelihood to talk about new topics. - #[serde(default)] - pub presence_penalty: Option, - /// If specified, our system will make a best effort to sample deterministically, - /// such that repeated requests with the same seed and parameters should return the same result. - #[serde(default)] - pub seed: Option, - /// Up to 4 sequences where the API will stop generating further tokens. - /// The returned text will not contain the stop sequence. - #[serde(default)] - pub stop: Option, - /// Whether to stream back partial progress. - /// If set, tokens will be sent as data-only server-sent events as they become available, - /// with the stream terminated by a data: [DONE] message. - #[serde(default)] - pub stream: Option, - #[serde(default)] - /// The suffix that comes after a completion of inserted text. - pub suffix: Option, - /// What sampling temperature to use, between 0 and 2. - /// Higher values like 0.8 will make the output more random, - /// while lower values like 0.2 will make it more focused and deterministic. - #[serde(default)] - pub temperature: Option, - /// An alternative to sampling with temperature, called nucleus sampling, - /// where the model considers the results of the tokens with top_p probability mass. - /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. - #[serde(default)] - pub top_p: Option, - - // Additional vllm params - #[serde(default)] - pub use_beam_search: Option, - #[serde(default)] - pub top_k: Option, - #[serde(default)] - pub min_p: Option, - #[serde(default)] - pub repetition_penalty: Option, - #[serde(default)] - pub length_penalty: Option, - #[serde(default)] - pub early_stopping: Option, - #[serde(default)] - pub stop_token_ids: Option>, - #[serde(default)] - pub ignore_eos: Option, - #[serde(default)] - pub min_tokens: Option, - #[serde(default)] - pub skip_special_tokens: Option, - #[serde(default)] - pub spaces_between_special_tokens: Option, +pub struct Usage { + /// Number of tokens in the generated completion. + pub completion_tokens: u32, + /// Number of tokens in the prompt. + pub prompt_tokens: u32, + /// Total number of tokens used in the request (prompt + completion). + pub total_tokens: u32, + /// Breakdown of tokens used in a completion. + pub completion_token_details: CompletionTokenDetails, + /// Breakdown of tokens used in the prompt. + pub prompt_token_details: PromptTokenDetails, } -// #[derive(Debug, Clone, Serialize, Deserialize)] -// #[serde(untagged)] -// pub enum Prompt { -// Array(Vec), -// String(String), -// } - -/// Represents a completion response from the API. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CompletionResponse { - /// A unique identifier for the completion. - pub id: String, - /// The object type, which is always `text_completion`. - pub object: String, - /// The Unix timestamp (in seconds) of when the completion was created. - pub created: i64, - /// The model used for the completion. - pub model: String, - /// This fingerprint represents the backend configuration that the model runs with. - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, - /// A list of completion choices. Can be more than one if n is greater than 1. - pub choices: Vec, - /// Usage statistics for the completion request. - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, +pub struct CompletionTokenDetails { + pub audio_tokens: u32, + pub reasoning_tokens: u32, } -/// A completion choice. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CompletionChoice { - /// The index of the choice in the list of choices. - pub index: u32, - /// A chat completion message generated by the model. - pub text: Option, - /// Log probability information for the choice. - pub logprobs: Option, - /// The reason the model stopped generating tokens. - #[serde(skip_serializing_if = "Option::is_none")] - pub finish_reason: Option, +pub struct PromptTokenDetails { + pub audio_tokens: u32, + pub cached_tokens: u32, } -/// Log probability information for a choice. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CompletionLogprobs { - pub text_offset: Vec, - pub token_logprobs: Vec, - pub tokens: Vec, - pub top_logprobs: Option>>, +#[serde(untagged)] +pub enum StopTokens { + Array(Vec), + String(String), } From c19783b6859e86bc9f20a889fc277457d47ac684 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:38:42 -0700 Subject: [PATCH 08/50] Update docs/architecture/adrs/006-detector-type.md Co-authored-by: Gaurav Kumbhat Signed-off-by: Dan Clark <44146800+declark1@users.noreply.github.com> --- docs/architecture/adrs/006-detector-type.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/architecture/adrs/006-detector-type.md b/docs/architecture/adrs/006-detector-type.md index cb3a38da..8579b047 100644 --- a/docs/architecture/adrs/006-detector-type.md +++ b/docs/architecture/adrs/006-detector-type.md @@ -5,7 +5,7 @@ This ADR documents the decision of adding the `type` parameter for detectors in ## Motivation The guardrails orchestrator interfaces with different types of detectors. -Detectors of a given are type are compatible with only a subset of orchestrator endpoints. +Detectors of a given type are compatible with only a subset of orchestrator endpoints. In order to reduce changes of misconfiguration, we need a way to map detectors to be used only with compatible endpoints. From 975ee6968267c4b75d2ec94eb3aa8cbb4201a869 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:38:56 -0700 Subject: [PATCH 09/50] Update docs/architecture/adrs/006-detector-type.md Co-authored-by: Gaurav Kumbhat Signed-off-by: Dan Clark <44146800+declark1@users.noreply.github.com> --- docs/architecture/adrs/006-detector-type.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/architecture/adrs/006-detector-type.md b/docs/architecture/adrs/006-detector-type.md index 8579b047..4ccef055 100644 --- a/docs/architecture/adrs/006-detector-type.md +++ b/docs/architecture/adrs/006-detector-type.md @@ -6,7 +6,7 @@ This ADR documents the decision of adding the `type` parameter for detectors in The guardrails orchestrator interfaces with different types of detectors. Detectors of a given type are compatible with only a subset of orchestrator endpoints. -In order to reduce changes of misconfiguration, we need a way to map detectors to be used only with compatible endpoints. +In order to reduce changes of misconfiguration, we need a way to map detectors to be used only with compatible endpoints. This would additionally provide a way for us to refer to a particular detector type within the code, without looking at its `hostname` (url) , which can be error prone. Good example for this is validating if certain detector would work with certain orchestrator endpoint or not. ## Decision From 046e285236b6dd4d6ddd09d97e0dfa65a7149d75 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:39:44 -0700 Subject: [PATCH 10/50] Update docs/architecture/adrs/006-detector-type.md Co-authored-by: Gaurav Kumbhat Signed-off-by: Dan Clark <44146800+declark1@users.noreply.github.com> --- docs/architecture/adrs/006-detector-type.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/architecture/adrs/006-detector-type.md b/docs/architecture/adrs/006-detector-type.md index 4ccef055..9ad4dfe3 100644 --- a/docs/architecture/adrs/006-detector-type.md +++ b/docs/architecture/adrs/006-detector-type.md @@ -12,7 +12,7 @@ In order to reduce changes of misconfiguration, we need a way to map detectors t ## Decision We decided to add the `type` parameter to the detectors configuration. -Possible values are `text_contents`, `text_context_chat`, `text_generation` and `text_context_doc`. +Possible values are `text_contents`, `text_chat`, `text_generation` and `text_context_doc`. Below is an example of detector configuration. ```yaml From 42d431b7da2cc90dbc6a1a9fb50429134a70639f Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:40:16 -0700 Subject: [PATCH 11/50] Update docs/architecture/adrs/006-detector-type.md Co-authored-by: Gaurav Kumbhat Signed-off-by: Dan Clark <44146800+declark1@users.noreply.github.com> --- docs/architecture/adrs/006-detector-type.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/architecture/adrs/006-detector-type.md b/docs/architecture/adrs/006-detector-type.md index 9ad4dfe3..99c4e077 100644 --- a/docs/architecture/adrs/006-detector-type.md +++ b/docs/architecture/adrs/006-detector-type.md @@ -36,7 +36,7 @@ detectors: 5. Not including `type` results in a configuration validation error on orchestrator startup. 6. Detector endpoints are automatically configured based on `type` as follows: * `text_contents` -> `/api/v1/text/contents` - * `text_context_chat` -> `/api/v1/text/context/chat` + * `text_chat` -> `/api/v1/text/chat` * `text_context_doc` -> `/api/v1/text/context/doc` * `text_generation` -> `/api/v1/text/generation` From 4dd61fd9a7bf223833d15288cc49674525781406 Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:44:41 -0700 Subject: [PATCH 12/50] Rename TextContextChatDetector to TextChatDetector, update DetectorType and example config Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- config/config.yaml | 5 ++--- src/clients/detector.rs | 4 ++-- .../{text_context_chat.rs => text_chat.rs} | 16 ++++++---------- src/config.rs | 2 +- src/orchestrator.rs | 11 ++++------- 5 files changed, 15 insertions(+), 23 deletions(-) rename src/clients/detector/{text_context_chat.rs => text_chat.rs} (59%) diff --git a/config/config.yaml b/config/config.yaml index 6ce600aa..7d277050 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -32,7 +32,7 @@ chunkers: detectors: # Detector ID/name to be used in user requests hap-en: - # Detector type (text_contents, text_generation, text_context_chat, text_context_doc) + # Detector type (text_contents, text_generation, text_chat, text_context_doc) type: text_contents service: hostname: localhost @@ -61,8 +61,7 @@ tls: detector_bundle_no_ca: cert_path: /path/to/client-bundle.pem insecure: true - # Following section can be used to configure the allowed headers that orchestrator will pass to # NLP provider and detectors. Note that, this section takes header keys, not values. # passthrough_headers: -# - header-key \ No newline at end of file +# - header-key diff --git a/src/clients/detector.rs b/src/clients/detector.rs index c00c1dfc..d8efba03 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -17,8 +17,8 @@ pub mod text_contents; pub use text_contents::*; -pub mod text_context_chat; -pub use text_context_chat::*; +pub mod text_chat; +pub use text_chat::*; pub mod text_context_doc; pub use text_context_doc::*; pub mod text_generation; diff --git a/src/clients/detector/text_context_chat.rs b/src/clients/detector/text_chat.rs similarity index 59% rename from src/clients/detector/text_context_chat.rs rename to src/clients/detector/text_chat.rs index 81e06118..87750895 100644 --- a/src/clients/detector/text_context_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -7,31 +7,27 @@ use crate::{ #[cfg_attr(test, faux::create)] #[derive(Clone)] -pub struct TextContextChatDetectorClient { +pub struct TextChatDetectorClient { client: HttpClient, } #[cfg_attr(test, faux::methods)] -impl TextContextChatDetectorClient { +impl TextChatDetectorClient { pub fn new(client: HttpClient) -> Self { Self { client } } - pub async fn text_context_chat(&self) { - let _url = self - .client - .base_url() - .join("/api/v1/text/context/chat") - .unwrap(); + pub async fn text_chat(&self) { + let _url = self.client.base_url().join("/api/v1/text/chat").unwrap(); todo!() } } #[cfg_attr(test, faux::methods)] #[async_trait] -impl Client for TextContextChatDetectorClient { +impl Client for TextChatDetectorClient { fn name(&self) -> &str { - "text_context_chat_detector" + "text_chat_detector" } async fn health(&self) -> HealthCheckResult { diff --git a/src/config.rs b/src/config.rs index ed002804..209cf028 100644 --- a/src/config.rs +++ b/src/config.rs @@ -160,7 +160,7 @@ pub enum DetectorType { #[default] TextContents, TextGeneration, - TextContextChat, + TextChat, TextContextDoc, } diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 48e2d2b2..394074b6 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -31,8 +31,8 @@ use crate::{ clients::{ create_grpc_client, create_http_client, detector::{ - text_context_doc::ContextType, TextContextChatDetectorClient, - TextContextDocDetectorClient, TextGenerationDetectorClient, + text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, + TextGenerationDetectorClient, }, openai::OpenAiClient, ChunkerClient, ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, @@ -255,11 +255,8 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { TextGenerationDetectorClient::new(client), ); } - DetectorType::TextContextChat => { - clients.insert( - detector_id.into(), - TextContextChatDetectorClient::new(client), - ); + DetectorType::TextChat => { + clients.insert(detector_id.into(), TextChatDetectorClient::new(client)); } DetectorType::TextContextDoc => { clients.insert( From e7b6226f10d2325f70412a8c6e39ce4ebc861e6c Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Mon, 7 Oct 2024 15:09:32 -0700 Subject: [PATCH 13/50] Drop provider from ChatGenerationConfig as it will always useopenai, drop GenerationProvider::OpenAi variant Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- config/config.yaml | 1 - config/test.config.yaml | 1 - src/config.rs | 19 ------------------- src/orchestrator.rs | 13 +++---------- 4 files changed, 3 insertions(+), 31 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 7d277050..29bdfc73 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -13,7 +13,6 @@ generation: port: 8033 # Generation server used for chat endpoints # chat_generation: -# provider: openai # service: # hostname: localhost # port: 8080 diff --git a/config/test.config.yaml b/config/test.config.yaml index 983a6234..cdde1905 100644 --- a/config/test.config.yaml +++ b/config/test.config.yaml @@ -4,7 +4,6 @@ generation: hostname: localhost port: 443 # chat_generation: -# provider: openai # service: # hostname: localhost # port: 8080 diff --git a/src/config.rs b/src/config.rs index 209cf028..fba4ab69 100644 --- a/src/config.rs +++ b/src/config.rs @@ -94,8 +94,6 @@ pub enum GenerationProvider { Tgis, #[serde(rename = "nlp")] Nlp, - #[serde(rename = "openai")] - OpenAi, } /// Generation service configuration @@ -112,8 +110,6 @@ pub struct GenerationConfig { #[cfg_attr(test, derive(Default))] #[derive(Clone, Debug, Deserialize)] pub struct ChatGenerationConfig { - /// Generation service provider - pub provider: GenerationProvider, /// Generation service connection information pub service: ServiceConfig, } @@ -299,15 +295,6 @@ impl OrchestratorConfig { // Generation config is valid if let Some(generation) = &self.generation { - // Provider is valid - if !matches!( - generation.provider, - GenerationProvider::Tgis | GenerationProvider::Nlp - ) { - return Err(Error::InvalidGenerationProvider( - "`generation` requires `tgis` or `nlp` provider".into(), - )); - } // Hostname is valid if !is_valid_hostname(&generation.service.hostname) { return Err(Error::InvalidHostname( @@ -318,12 +305,6 @@ impl OrchestratorConfig { // Chat generation config is valid if let Some(chat_generation) = &self.chat_generation { - // Provider is valid - if !matches!(chat_generation.provider, GenerationProvider::OpenAi) { - return Err(Error::InvalidGenerationProvider( - "`chat_generation` requires `openai` provider".into(), - )); - } // Hostname is valid if !is_valid_hostname(&chat_generation.service.hostname) { return Err(Error::InvalidHostname( diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 394074b6..4ea5b73a 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -209,21 +209,14 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { let generation_client = GenerationClient::nlp(nlp_client); clients.insert("generation".to_string(), generation_client); } - GenerationProvider::OpenAi => unimplemented!(), } } // Create chat generation client if let Some(chat_generation) = &config.chat_generation { - match chat_generation.provider { - GenerationProvider::OpenAi => { - let client = - create_http_client(DEFAULT_OPENAI_PORT, &chat_generation.service).await; - let openai_client = OpenAiClient::new(client); - clients.insert("chat_generation".to_string(), openai_client); - } - _ => unimplemented!(), - } + let client = create_http_client(DEFAULT_OPENAI_PORT, &chat_generation.service).await; + let openai_client = OpenAiClient::new(client); + clients.insert("chat_generation".to_string(), openai_client); } // Create chunker clients From 355de7d1b836a45d39e4fad0f43ae8d6383e0553 Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:48:26 -0700 Subject: [PATCH 14/50] Split config validation rules into methods Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- src/config.rs | 64 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/src/config.rs b/src/config.rs index fba4ab69..6f4a911b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -259,7 +259,43 @@ impl OrchestratorConfig { return Err(Error::NoDetectorsConfigured); } - // Detector configs are valid + // Apply validation rules + self.validate_generation_config()?; + self.validate_chat_generation_config()?; + self.validate_detector_configs()?; + self.validate_chunker_configs()?; + + Ok(()) + } + + /// Validates generation config. + fn validate_generation_config(&self) -> Result<(), Error> { + if let Some(generation) = &self.generation { + // Hostname is valid + if !is_valid_hostname(&generation.service.hostname) { + return Err(Error::InvalidHostname( + "`generation` has an invalid hostname".into(), + )); + } + } + Ok(()) + } + + /// Validates chat generation config. + fn validate_chat_generation_config(&self) -> Result<(), Error> { + if let Some(chat_generation) = &self.chat_generation { + // Hostname is valid + if !is_valid_hostname(&chat_generation.service.hostname) { + return Err(Error::InvalidHostname( + "`chat_generation` has an invalid hostname".into(), + )); + } + } + Ok(()) + } + + /// Validates detector configs. + fn validate_detector_configs(&self) -> Result<(), Error> { for (detector_id, detector) in &self.detectors { // Hostname is valid if !is_valid_hostname(&detector.service.hostname) { @@ -280,8 +316,11 @@ impl OrchestratorConfig { }); } } + Ok(()) + } - // Chunker config is valid + /// Validates chunker configs. + fn validate_chunker_configs(&self) -> Result<(), Error> { if let Some(chunkers) = &self.chunkers { for (chunker_id, chunker) in chunkers { // Hostname is valid @@ -292,27 +331,6 @@ impl OrchestratorConfig { } } } - - // Generation config is valid - if let Some(generation) = &self.generation { - // Hostname is valid - if !is_valid_hostname(&generation.service.hostname) { - return Err(Error::InvalidHostname( - "`generation` has an invalid hostname".into(), - )); - } - } - - // Chat generation config is valid - if let Some(chat_generation) = &self.chat_generation { - // Hostname is valid - if !is_valid_hostname(&chat_generation.service.hostname) { - return Err(Error::InvalidHostname( - "`chat_generation` has an invalid hostname".into(), - )); - } - } - Ok(()) } From ee1e7824d09fe2fdb159079b8173dc50e05de804 Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:46:37 -0700 Subject: [PATCH 15/50] Add health_service to DetectorConfig and ChatGenerationConfig, add health_client to detector clients and OpenAiClient Co-authored-by: Paul Scoropan <1paulscoropan@gmail.com> Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- config/config.yaml | 4 ++++ src/clients/detector/text_chat.rs | 14 ++++++++++--- src/clients/detector/text_contents.rs | 14 ++++++++++--- src/clients/detector/text_context_doc.rs | 14 ++++++++++--- src/clients/detector/text_generation.rs | 14 ++++++++++--- src/clients/openai.rs | 14 ++++++++++--- src/config.rs | 4 ++++ src/orchestrator.rs | 26 +++++++++++++++++++----- 8 files changed, 84 insertions(+), 20 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 29bdfc73..ef5be2c4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -16,6 +16,7 @@ generation: # service: # hostname: localhost # port: 8080 +# # health_service: # Any chunker servers that will be used by any detectors chunkers: # Chunker ID/name @@ -38,6 +39,9 @@ detectors: port: 8080 # TLS ID/name, optional (detailed in `tls` section) tls: detector + health_service: + hostname: localhost + port: 8081 # Chunker ID/name from `chunkers` section if applicable chunker_id: en_regex # Default score threshold for a detector. If a user diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index 87750895..ae528f66 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -9,12 +9,16 @@ use crate::{ #[derive(Clone)] pub struct TextChatDetectorClient { client: HttpClient, + health_client: Option, } #[cfg_attr(test, faux::methods)] impl TextChatDetectorClient { - pub fn new(client: HttpClient) -> Self { - Self { client } + pub fn new(client: HttpClient, health_client: Option) -> Self { + Self { + client, + health_client, + } } pub async fn text_chat(&self) { @@ -31,6 +35,10 @@ impl Client for TextChatDetectorClient { } async fn health(&self) -> HealthCheckResult { - self.client.health().await + if let Some(health_client) = &self.health_client { + health_client.health().await + } else { + self.client.health().await + } } } diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index 50a8013d..c98e7b32 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -12,12 +12,16 @@ use crate::{ #[derive(Clone)] pub struct TextContentsDetectorClient { client: HttpClient, + health_client: Option, } #[cfg_attr(test, faux::methods)] impl TextContentsDetectorClient { - pub fn new(client: HttpClient) -> Self { - Self { client } + pub fn new(client: HttpClient, health_client: Option) -> Self { + Self { + client, + health_client, + } } pub async fn text_contents( @@ -63,7 +67,11 @@ impl Client for TextContentsDetectorClient { } async fn health(&self) -> HealthCheckResult { - self.client.health().await + if let Some(health_client) = &self.health_client { + health_client.health().await + } else { + self.client.health().await + } } } diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index aca518f7..a87a4be4 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -13,12 +13,16 @@ use crate::{ #[derive(Clone)] pub struct TextContextDocDetectorClient { client: HttpClient, + health_client: Option, } #[cfg_attr(test, faux::methods)] impl TextContextDocDetectorClient { - pub fn new(client: HttpClient) -> Self { - Self { client } + pub fn new(client: HttpClient, health_client: Option) -> Self { + Self { + client, + health_client, + } } pub async fn text_context_doc( @@ -64,7 +68,11 @@ impl Client for TextContextDocDetectorClient { } async fn health(&self) -> HealthCheckResult { - self.client.health().await + if let Some(health_client) = &self.health_client { + health_client.health().await + } else { + self.client.health().await + } } } diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 7fcfe987..6ece86b7 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -13,12 +13,16 @@ use crate::{ #[derive(Clone)] pub struct TextGenerationDetectorClient { client: HttpClient, + health_client: Option, } #[cfg_attr(test, faux::methods)] impl TextGenerationDetectorClient { - pub fn new(client: HttpClient) -> Self { - Self { client } + pub fn new(client: HttpClient, health_client: Option) -> Self { + Self { + client, + health_client, + } } pub async fn text_generation( @@ -64,7 +68,11 @@ impl Client for TextGenerationDetectorClient { } async fn health(&self) -> HealthCheckResult { - self.client.health().await + if let Some(health_client) = &self.health_client { + health_client.health().await + } else { + self.client.health().await + } } } diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 8dd1f03c..56ff9b46 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -11,12 +11,16 @@ use crate::health::HealthCheckResult; #[derive(Clone)] pub struct OpenAiClient { client: HttpClient, + health_client: Option, } #[cfg_attr(test, faux::methods)] impl OpenAiClient { - pub fn new(client: HttpClient) -> Self { - Self { client } + pub fn new(client: HttpClient, health_client: Option) -> Self { + Self { + client, + health_client, + } } pub async fn chat_completions( @@ -43,7 +47,11 @@ impl Client for OpenAiClient { } async fn health(&self) -> HealthCheckResult { - self.client.health().await + if let Some(health_client) = &self.health_client { + health_client.health().await + } else { + self.client.health().await + } } } diff --git a/src/config.rs b/src/config.rs index 6f4a911b..89c4d559 100644 --- a/src/config.rs +++ b/src/config.rs @@ -112,6 +112,8 @@ pub struct GenerationConfig { pub struct ChatGenerationConfig { /// Generation service connection information pub service: ServiceConfig, + /// Generation health service connection information + pub health_service: Option, } /// Chunker parser type @@ -140,6 +142,8 @@ pub struct ChunkerConfig { pub struct DetectorConfig { /// Detector service connection information pub service: ServiceConfig, + /// Detector health service connection information + pub health_service: Option, /// ID of chunker that this detector will use pub chunker_id: String, /// Default threshold with which to filter detector results by score diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 4ea5b73a..40d8af77 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -215,7 +215,12 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { // Create chat generation client if let Some(chat_generation) = &config.chat_generation { let client = create_http_client(DEFAULT_OPENAI_PORT, &chat_generation.service).await; - let openai_client = OpenAiClient::new(client); + let health_client = if let Some(health_service) = &chat_generation.health_service { + Some(create_http_client(DEFAULT_OPENAI_PORT, health_service).await) + } else { + None + }; + let openai_client = OpenAiClient::new(client, health_client); clients.insert("chat_generation".to_string(), openai_client); } @@ -238,23 +243,34 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { // Create detector clients for (detector_id, detector) in &config.detectors { let client = create_http_client(DEFAULT_DETECTOR_PORT, &detector.service).await; + let health_client = if let Some(health_service) = &detector.health_service { + Some(create_http_client(DEFAULT_DETECTOR_PORT, health_service).await) + } else { + None + }; match detector.r#type { DetectorType::TextContents => { - clients.insert(detector_id.into(), TextContentsDetectorClient::new(client)); + clients.insert( + detector_id.into(), + TextContentsDetectorClient::new(client, health_client), + ); } DetectorType::TextGeneration => { clients.insert( detector_id.into(), - TextGenerationDetectorClient::new(client), + TextGenerationDetectorClient::new(client, health_client), ); } DetectorType::TextChat => { - clients.insert(detector_id.into(), TextChatDetectorClient::new(client)); + clients.insert( + detector_id.into(), + TextChatDetectorClient::new(client, health_client), + ); } DetectorType::TextContextDoc => { clients.insert( detector_id.into(), - TextContextDocDetectorClient::new(client), + TextContextDocDetectorClient::new(client, health_client), ); } } From 428c8063ad2cb5deee57c3c63a1524358795856b Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:39:49 -0700 Subject: [PATCH 16/50] Move inner client creation back to client constructor Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- src/clients/chunker.rs | 11 +-- src/clients/detector.rs | 1 + src/clients/detector/text_chat.rs | 12 +++- src/clients/detector/text_contents.rs | 13 +++- src/clients/detector/text_context_doc.rs | 13 +++- src/clients/detector/text_generation.rs | 13 +++- src/clients/nlp.rs | 11 +-- src/clients/openai.rs | 14 +++- src/clients/tgis.rs | 8 ++- src/orchestrator.rs | 86 ++++++++---------------- 10 files changed, 98 insertions(+), 84 deletions(-) diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index e0113b97..8f21f821 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -25,8 +25,9 @@ use tokio_stream::wrappers::ReceiverStream; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::info; -use super::{errors::grpc_to_http_code, BoxStream, Client, Error}; +use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ + config::ServiceConfig, health::{HealthCheckResult, HealthStatus}, pb::{ caikit::runtime::chunkers::{ @@ -38,6 +39,7 @@ use crate::{ }, }; +const DEFAULT_PORT: u16 = 8085; const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; /// Default chunker that returns span for entire text pub const DEFAULT_MODEL_ID: &str = "whole_doc_chunker"; @@ -54,10 +56,9 @@ pub struct ChunkerClient { #[cfg_attr(test, faux::methods)] impl ChunkerClient { - pub fn new( - client: ChunkersServiceClient, - health_client: HealthClient, - ) -> Self { + pub async fn new(config: &ServiceConfig) -> Self { + let client = create_grpc_client(DEFAULT_PORT, config, ChunkersServiceClient::new).await; + let health_client = create_grpc_client(DEFAULT_PORT, config, HealthClient::new).await; Self { client, health_client, diff --git a/src/clients/detector.rs b/src/clients/detector.rs index d8efba03..80e0cb70 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -28,6 +28,7 @@ pub use text_generation::*; use super::Error; +const DEFAULT_PORT: u16 = 8080; const DETECTOR_ID_HEADER_NAME: &str = "detector-id"; #[derive(Debug, Clone, Deserialize)] diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index ae528f66..39214c43 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; +use super::DEFAULT_PORT; use crate::{ - clients::{Client, HttpClient}, + clients::{create_http_client, Client, HttpClient}, + config::ServiceConfig, health::HealthCheckResult, }; @@ -14,7 +16,13 @@ pub struct TextChatDetectorClient { #[cfg_attr(test, faux::methods)] impl TextChatDetectorClient { - pub fn new(client: HttpClient, health_client: Option) -> Self { + pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self { + let client = create_http_client(DEFAULT_PORT, config).await; + let health_client = if let Some(health_config) = health_config { + Some(create_http_client(DEFAULT_PORT, health_config).await) + } else { + None + }; Self { client, health_client, diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index c98e7b32..a146e6a4 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -2,9 +2,10 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; -use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; +use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ - clients::{Client, Error, HttpClient}, + clients::{create_http_client, Client, Error, HttpClient}, + config::ServiceConfig, health::HealthCheckResult, }; @@ -17,7 +18,13 @@ pub struct TextContentsDetectorClient { #[cfg_attr(test, faux::methods)] impl TextContentsDetectorClient { - pub fn new(client: HttpClient, health_client: Option) -> Self { + pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self { + let client = create_http_client(DEFAULT_PORT, config).await; + let health_client = if let Some(health_config) = health_config { + Some(create_http_client(DEFAULT_PORT, health_config).await) + } else { + None + }; Self { client, health_client, diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index a87a4be4..6ef4661d 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -2,9 +2,10 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; -use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; +use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ - clients::{Client, Error, HttpClient}, + clients::{create_http_client, Client, Error, HttpClient}, + config::ServiceConfig, health::HealthCheckResult, models::{DetectionResult, DetectorParams}, }; @@ -18,7 +19,13 @@ pub struct TextContextDocDetectorClient { #[cfg_attr(test, faux::methods)] impl TextContextDocDetectorClient { - pub fn new(client: HttpClient, health_client: Option) -> Self { + pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self { + let client = create_http_client(DEFAULT_PORT, config).await; + let health_client = if let Some(health_config) = health_config { + Some(create_http_client(DEFAULT_PORT, health_config).await) + } else { + None + }; Self { client, health_client, diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 6ece86b7..b1190daa 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -2,9 +2,10 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::Serialize; -use super::{DetectorError, DETECTOR_ID_HEADER_NAME}; +use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ - clients::{Client, Error, HttpClient}, + clients::{create_http_client, Client, Error, HttpClient}, + config::ServiceConfig, health::HealthCheckResult, models::DetectionResult, }; @@ -18,7 +19,13 @@ pub struct TextGenerationDetectorClient { #[cfg_attr(test, faux::methods)] impl TextGenerationDetectorClient { - pub fn new(client: HttpClient, health_client: Option) -> Self { + pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self { + let client = create_http_client(DEFAULT_PORT, config).await; + let health_client = if let Some(health_config) = health_config { + Some(create_http_client(DEFAULT_PORT, health_config).await) + } else { + None + }; Self { client, health_client, diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index c40f9490..f6c873a6 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -21,8 +21,9 @@ use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::{metadata::MetadataMap, Code, Request}; -use super::{errors::grpc_to_http_code, BoxStream, Client, Error}; +use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ + config::ServiceConfig, health::{HealthCheckResult, HealthStatus}, pb::{ caikit::runtime::nlp::{ @@ -37,6 +38,7 @@ use crate::{ }, }; +const DEFAULT_PORT: u16 = 8085; const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; #[cfg_attr(test, faux::create)] @@ -48,10 +50,9 @@ pub struct NlpClient { #[cfg_attr(test, faux::methods)] impl NlpClient { - pub fn new( - client: NlpServiceClient, - health_client: HealthClient, - ) -> Self { + pub async fn new(config: &ServiceConfig) -> Self { + let client = create_grpc_client(DEFAULT_PORT, config, NlpServiceClient::new).await; + let health_client = create_grpc_client(DEFAULT_PORT, config, HealthClient::new).await; Self { client, health_client, diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 56ff9b46..584ba436 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -4,8 +4,10 @@ use async_trait::async_trait; use hyper::StatusCode; use serde::{Deserialize, Serialize}; -use super::{Client, Error, HttpClient}; -use crate::health::HealthCheckResult; +use super::{create_http_client, Client, Error, HttpClient}; +use crate::{config::ServiceConfig, health::HealthCheckResult}; + +const DEFAULT_PORT: u16 = 8080; #[cfg_attr(test, faux::create)] #[derive(Clone)] @@ -16,7 +18,13 @@ pub struct OpenAiClient { #[cfg_attr(test, faux::methods)] impl OpenAiClient { - pub fn new(client: HttpClient, health_client: Option) -> Self { + pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self { + let client = create_http_client(DEFAULT_PORT, config).await; + let health_client = if let Some(health_config) = health_config { + Some(create_http_client(DEFAULT_PORT, health_config).await) + } else { + None + }; Self { client, health_client, diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index 1a779b09..1195d642 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -20,8 +20,9 @@ use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::Code; -use super::{errors::grpc_to_http_code, BoxStream, Client, Error}; +use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ + config::ServiceConfig, health::{HealthCheckResult, HealthStatus}, pb::fmaas::{ generation_service_client::GenerationServiceClient, BatchedGenerationRequest, @@ -30,6 +31,8 @@ use crate::{ }, }; +const DEFAULT_PORT: u16 = 8033; + #[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TgisClient { @@ -38,7 +41,8 @@ pub struct TgisClient { #[cfg_attr(test, faux::methods)] impl TgisClient { - pub fn new(client: GenerationServiceClient) -> Self { + pub async fn new(config: &ServiceConfig) -> Self { + let client = create_grpc_client(DEFAULT_PORT, config, GenerationServiceClient::new).await; Self { client } } diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 40d8af77..82939d18 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -29,7 +29,6 @@ use uuid::Uuid; use crate::{ clients::{ - create_grpc_client, create_http_client, detector::{ text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, @@ -45,22 +44,8 @@ use crate::{ GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest, }, - pb::{ - caikit::runtime::{ - chunkers::chunkers_service_client::ChunkersServiceClient, - nlp::nlp_service_client::NlpServiceClient, - }, - fmaas::generation_service_client::GenerationServiceClient, - grpc::health::v1::health_client::HealthClient, - }, }; -const DEFAULT_TGIS_PORT: u16 = 8033; -const DEFAULT_NLP_PORT: u16 = 8085; -const DEFAULT_CHUNKER_PORT: u16 = 8085; -const DEFAULT_OPENAI_PORT: u16 = 8080; -const DEFAULT_DETECTOR_PORT: u16 = 8080; - const UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. \ Please check the detected entities on your input and try again \ with the unsuitable input removed."; @@ -185,27 +170,12 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { if let Some(generation) = &config.generation { match generation.provider { GenerationProvider::Tgis => { - let client = create_grpc_client( - DEFAULT_TGIS_PORT, - &generation.service, - GenerationServiceClient::new, - ) - .await; - let tgis_client = TgisClient::new(client); + let tgis_client = TgisClient::new(&generation.service).await; let generation_client = GenerationClient::tgis(tgis_client); clients.insert("generation".to_string(), generation_client); } GenerationProvider::Nlp => { - let client = create_grpc_client( - DEFAULT_NLP_PORT, - &generation.service, - NlpServiceClient::new, - ) - .await; - let health_client = - create_grpc_client(DEFAULT_NLP_PORT, &generation.service, HealthClient::new) - .await; - let nlp_client = NlpClient::new(client, health_client); + let nlp_client = NlpClient::new(&generation.service).await; let generation_client = GenerationClient::nlp(nlp_client); clients.insert("generation".to_string(), generation_client); } @@ -214,63 +184,63 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { // Create chat generation client if let Some(chat_generation) = &config.chat_generation { - let client = create_http_client(DEFAULT_OPENAI_PORT, &chat_generation.service).await; - let health_client = if let Some(health_service) = &chat_generation.health_service { - Some(create_http_client(DEFAULT_OPENAI_PORT, health_service).await) - } else { - None - }; - let openai_client = OpenAiClient::new(client, health_client); + let openai_client = OpenAiClient::new( + &chat_generation.service, + chat_generation.health_service.as_ref(), + ) + .await; clients.insert("chat_generation".to_string(), openai_client); } // Create chunker clients if let Some(chunkers) = &config.chunkers { for (chunker_id, chunker) in chunkers { - let client = create_grpc_client( - DEFAULT_CHUNKER_PORT, - &chunker.service, - ChunkersServiceClient::new, - ) - .await; - let health_client = - create_grpc_client(DEFAULT_CHUNKER_PORT, &chunker.service, HealthClient::new).await; - let chunker_client = ChunkerClient::new(client, health_client); + let chunker_client = ChunkerClient::new(&chunker.service).await; clients.insert(chunker_id.to_string(), chunker_client); } } // Create detector clients for (detector_id, detector) in &config.detectors { - let client = create_http_client(DEFAULT_DETECTOR_PORT, &detector.service).await; - let health_client = if let Some(health_service) = &detector.health_service { - Some(create_http_client(DEFAULT_DETECTOR_PORT, health_service).await) - } else { - None - }; match detector.r#type { DetectorType::TextContents => { clients.insert( detector_id.into(), - TextContentsDetectorClient::new(client, health_client), + TextContentsDetectorClient::new( + &detector.service, + detector.health_service.as_ref(), + ) + .await, ); } DetectorType::TextGeneration => { clients.insert( detector_id.into(), - TextGenerationDetectorClient::new(client, health_client), + TextGenerationDetectorClient::new( + &detector.service, + detector.health_service.as_ref(), + ) + .await, ); } DetectorType::TextChat => { clients.insert( detector_id.into(), - TextChatDetectorClient::new(client, health_client), + TextChatDetectorClient::new( + &detector.service, + detector.health_service.as_ref(), + ) + .await, ); } DetectorType::TextContextDoc => { clients.insert( detector_id.into(), - TextContextDocDetectorClient::new(client, health_client), + TextContextDocDetectorClient::new( + &detector.service, + detector.health_service.as_ref(), + ) + .await, ); } } From a7892fcc15fb5c4dca9c61b1749da7371e6c2e5a Mon Sep 17 00:00:00 2001 From: pscoro <78318122+pscoro@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:50:08 -0400 Subject: [PATCH 17/50] Initial opentelemetry setup for orchestrator (#221) * otlp initial setup Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * small tweaks Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * some refactoring Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * missed fmting Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * Added telemetry ADR doc Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * review comments Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * missed nit Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * doc nits Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * rebase fix Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * docs traceparent update Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --------- Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- Cargo.lock | 130 +++++++ Cargo.toml | 12 +- Dockerfile | 11 +- .../adrs/007-orchestrator-telemetry.md | 96 ++++++ src/args.rs | 214 ++++++++++++ src/lib.rs | 2 + src/main.rs | 46 +-- src/server.rs | 19 +- src/tracing_utils.rs | 322 ++++++++++++++++++ 9 files changed, 802 insertions(+), 50 deletions(-) create mode 100644 docs/architecture/adrs/007-orchestrator-telemetry.md create mode 100644 src/args.rs create mode 100644 src/tracing_utils.rs diff --git a/Cargo.lock b/Cargo.lock index 60da0bbb..c558bf93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -608,6 +608,10 @@ dependencies = [ "hyper", "hyper-util", "mio", + "opentelemetry", + "opentelemetry-http", + "opentelemetry-otlp", + "opentelemetry_sdk", "prost", "reqwest", "rustls", @@ -622,8 +626,10 @@ dependencies = [ "tokio-stream", "tonic", "tonic-build", + "tower-http", "tower-service", "tracing", + "tracing-opentelemetry", "tracing-subscriber", "tracing-test", "url", @@ -1383,6 +1389,85 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + +[[package]] +name = "opentelemetry-http" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad31e9de44ee3538fb9d64fe3376c1362f406162434609e79aea2a41a0af78ab" +dependencies = [ + "async-trait", + "bytes", + "http 1.1.0", + "opentelemetry", + "reqwest", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b925a602ffb916fb7421276b86756027b37ee708f9dce2dbdcc51739f07e727" +dependencies = [ + "async-trait", + "futures-core", + "http 1.1.0", + "opentelemetry", + "opentelemetry-http", + "opentelemetry-proto", + "opentelemetry_sdk", + "prost", + "thiserror", + "tokio", + "tonic", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9" +dependencies = [ + "opentelemetry", + "opentelemetry_sdk", + "prost", + "tonic", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry", + "percent-encoding", + "rand", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", +] + [[package]] name = "overload" version = "0.1.1" @@ -2391,6 +2476,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags", + "bytes", + "http 1.1.0", + "http-body", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2447,6 +2549,24 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-opentelemetry" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry", + "opentelemetry_sdk", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + [[package]] name = "tracing-serde" version = "0.1.3" @@ -2724,6 +2844,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.6" diff --git a/Cargo.toml b/Cargo.toml index ef13667e..2d566564 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,14 +15,21 @@ path = "src/main.rs" [dependencies] anyhow = "1.0.86" +async-trait = "0.1.81" +async-stream = "0.3.5" axum = { version = "0.7.5", features = ["json"] } axum-extra = "0.9.3" clap = { version = "4.5.15", features = ["derive", "env"] } futures = "0.3.30" ginepro = "0.8.1" +http-serde = "2.1.1" hyper = { version = "1.4.1", features = ["http1", "http2", "server"] } hyper-util = { version = "0.1.7", features = ["server-auto", "server-graceful", "tokio"] } mio = "1.0.2" +opentelemetry = { version = "0.24.0", features = ["trace", "metrics"] } +opentelemetry-http = { version = "0.13.0", features = ["reqwest"] } +opentelemetry-otlp = { version = "0.17.0", features = ["http-proto"] } +opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio", "metrics"] } prost = "0.13.1" reqwest = { version = "0.12.5", features = ["blocking", "rustls-tls", "json"] } rustls = {version = "0.23.12", default-features = false, features = ["std"]} @@ -36,14 +43,13 @@ tokio = { version = "1.39.2", features = ["rt", "rt-multi-thread", "parking_lot" tokio-rustls = { version = "0.26.0" } tokio-stream = { version = "0.1.15", features = ["sync"] } tonic = { version = "0.12.1", features = ["tls", "tls-roots", "tls-webpki-roots"] } +tower-http = { version = "0.5.2", features = ["trace"] } tower-service = "0.3" tracing = "0.1.40" +tracing-opentelemetry = "0.25.0" tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } url = "2.5.2" uuid = { version = "1.10.0", features = ["v4", "fast-rng"] } -async-trait = "0.1.81" -async-stream = "0.3.5" -http-serde = "2.1.1" [build-dependencies] tonic-build = "0.12.1" diff --git a/Dockerfile b/Dockerfile index 6caee35c..d1913a4c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,7 @@ ARG UBI_MINIMAL_BASE_IMAGE=registry.access.redhat.com/ubi9/ubi-minimal ARG UBI_BASE_IMAGE_TAG=latest ARG PROTOC_VERSION=26.0 +ARG CONFIG_FILE=config/config.yaml ## Rust builder ################################################################ # Specific debian version so that compatible glibc version is used @@ -23,10 +24,10 @@ RUN rustup component add rustfmt ## Orchestrator builder ######################################################### FROM rust-builder as fms-guardrails-orchestr8-builder -COPY build.rs *.toml LICENSE /app -COPY config/ /app/config -COPY protos/ /app/protos -COPY src/ /app/src +COPY build.rs *.toml LICENSE /app/ +COPY ${CONFIG_FILE} /app/config/config.yaml +COPY protos/ /app/protos/ +COPY src/ /app/src/ WORKDIR /app @@ -50,7 +51,7 @@ RUN cargo fmt --check FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} as fms-guardrails-orchestr8-release COPY --from=fms-guardrails-orchestr8-builder /app/bin/ /app/bin/ -COPY config /app/config +COPY ${CONFIG_FILE} /app/config/config.yaml RUN microdnf install -y --disableplugin=subscription-manager shadow-utils compat-openssl11 && \ microdnf clean all --disableplugin=subscription-manager diff --git a/docs/architecture/adrs/007-orchestrator-telemetry.md b/docs/architecture/adrs/007-orchestrator-telemetry.md new file mode 100644 index 00000000..00690512 --- /dev/null +++ b/docs/architecture/adrs/007-orchestrator-telemetry.md @@ -0,0 +1,96 @@ +# ADR 007: Orchestrator Telemetry + +The guardrails orchestrator uses [OpenTelemetry](https://opentelemetry.io/) to collect and export telemetry data (traces, metrics, and logs). + +The orchestrator needs to collect telemetry data for monitoring and observability. It also needs to be able to trace +spans for incoming requests and across client requests to configured detectors, chunkers, and generation services and +aggregate detailed traces, metrics, and logs that can be monitored from a variety of observability backends. + +## Decision + +### OpenTelemetry and `tracing` + +The orchestrator and client services will make use of the OpenTelemetry SDK and the [OpenTelemetry Protocol (OTLP)](https://opentelemetry.io/docs/specs/otel/protocol/) +for consolidating and collecting telemetry data across services. The orchestrator will be responsible for collecting +telemetry data throughout the lifetime of a request using the `tracing` crate, which is the de facto choice for logging +and tracing for OpenTelemetry in Rust, and exporting it through the OTLP exporter if configured. The OTLP exporter will +send telemetry data to a gRPC or HTTP endpoint that can be configured to point to a running OpenTelemetry (OTEL) collector. +Similarly, detectors should also be able to collect and export telemetry through OTLP to the same OTEL collector. +From the OTEL collector, the telemetry data can then be exported to multiple backends. The OTEL collector and +any observability backends can all be configured alongside the orchestrator and detectors in a deployment. + +### Instrumentation +An incoming request to the orchestrator will initialize a new trace, therefore a trace-id and request should be in +one-to-one correspondence. All important functions throughout the control flow of handling a request in the orchestrator +will be instrumented with the `#[tracing::instrument]` attribute macro above the function definition. This will create +and enter a span for each function call and add it to the trace of the request. Here, important functions refers to any +functions that perform important business logic that may incur significant latency, including all the handler functions +for incoming and outgoing requests. It is up to the discretion of the developer to determine what functions are +"significant" enough to indicate a new span in the trace, but adding a new tracing span can always trivially be done by +just adding the instrument macro. + +### Metrics +The orchestrator will aggregate metrics regarding the requests it has received/handled, and annotate the metrics with +span attributes allowing for detailed filtering and monitoring. The metrics will be exported through the OTLP exporter +through the metrics provider. Traces exported through the traces provider can also have R.E.D. (request, error and +duration) metrics attached to them implicitly by the OTEL collector using the `spanmetrics` connector. Both the OTLP +metrics and the `spanmetrics` metrics can be exported to configured metrics backends like Prometheus or Grafana. +The orchestrator will handle a variety of useful metrics such as counters and histograms for received/handled +successful/failed requests, request and stream durations, and server/client errors. Traces and metrics will also relate +incoming orchestrator requests to respective client requests/responses, and collect more business specific metrics +e.g. regarding the outcome of running detection or generation. + +### Configuration +The orchestrator will expose CLI args/env variables for configuring the OTLP exporter: +- `OTEL_EXPORTER_OTLP_PROTOCOL=grpc|http` to set the protocol for all the OTLP endpoints + - `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` and `OTEL_EXPORTER_OTLP_METRICS_PROTOCOL` to set/override the protocol for + traces or metrics. +- `--otlp-endpoint, OTEL_EXPORTER_OTLP_ENDPOINT` to set the OTLP endpoint + - defaults: gRPC `localhost:4317` and HTTP `localhost:4318` + - `--otlp-traces-endpoint, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT` and `--otlp-metrics-endpoint, + OTEL_EXPORTER_OTLP_METRICS_ENDPOINT` to set/override the endpoint for traces or metrics + - default to `localhost:4317` for gRPC for all data types, and `localhost:4318/v1/traces`, or `metrics`, for HTTP +- `--otlp-export, OTLP_EXPORT=traces,metrics` to specify a list of which data types to export to the OTLP exporters, separated by a + comma. The possible values are traces, metrics, or both. The OTLP standard specifies three data types (`traces`, + `metrics`, `logs`) but since we use the recommended `tracing` crate for logging, we can export logs as traces and + not use the separate (more experimental) logging export pipeline. +- `RUST_LOG=error|warn|info|debug|trace` to set the Rust log level. +- `--log-format, LOG_FORMAT=full|compact|json|pretty` to set the logging format for logs printed to stdout. All logs collected as + traces by OTLP will just be structured traces, this argument is specifically for stdout. Default is `full`. +- `--quiet, -q` to silence logging to stdout. If `OTLP_EXPORT=traces` is still provided, all logs can still be viewed + as traces from an observability backend. + +### Cross-service tracing +The orchestrator and client services will be able to consolidate telemetry and share observability through a common +configuration and backends. This will be made possible through the use of the OTLP standard as well as through the +propagation of the trace context through requests across services using the standardized `traceparent` header. The +orchestrator will be expected to initialize a new trace for an incoming request and pass `traceparent` headers +corresponding to this trace to any requests outgoing to clients, and similarly, the orchestrator will expect the client +to provide a `traceparent` header in the response. The orchestrator will not propagate the `traceparent` to outgoing +responses back to the end user (or expect `traceparent` in incoming requests) for security reasons. + +## Status + +Proposed + +## Consequences + +- The orchestrator and client services have a common standard to conform to for telemetry, allowing for traceability + across different services. There does not exist any other attempts at telemetry standardization that is as widely + accepted as OpenTelemetry, or have the same level of support from existing observability and monitoring services. +- The deployment of the orchestrator must be configured with telemetry service(s) listening for telemetry exported on + the specified endpoint(s). An [OTEL collector](https://opentelemetry.io/docs/collector/) service can be used to + collect and propagate the telemetry data, or the export endpoint(s) can be listened to directly by any backend that + supports OTLP (e.g. Jaeger). +- The orchestrator and client services do not need to be concerned with specific observability backends, the OTEL + collector and OTLP standard can be used to export telemetry data to a variety of backends including Jaeger, + Prometheus, Grafana, and Instana, as well to OpenShift natively through the OpenTelemetryCollector CRD. +- Using the `tracing` crate in Rust for logging will treat logs as traces, allowing the orchestrator to export logs + through the trace provider (with OTLP exporter), simplifying the implementation and avoiding use of the logging + provider which is still considered experimental in many contexts (it exists for compatibility with non `tracing` + logging libraries). +- For stdout, the new `--log-format` and `--quiet` arguments add more configurability to format or silence logging. +- The integration of the OpenTelemetry API/SDK into the stack is not trivial, and the OpenTelemetry crates will incur + additional compile time to the orchestrator. +- The OpenTelemetry API/SDK and OTLP standard are new and still evolving, and the orchestrator will need to keep up + with changes in the OpenTelemetry ecosystem, there could be occasional breaking changes that will need addressing. \ No newline at end of file diff --git a/src/args.rs b/src/args.rs new file mode 100644 index 00000000..b5b38f03 --- /dev/null +++ b/src/args.rs @@ -0,0 +1,214 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +use std::{fmt::Display, path::PathBuf}; + +use clap::Parser; +use tracing::{error, warn}; + +#[derive(Parser, Debug, Clone)] +#[clap(author, version, about, long_about = None)] +pub struct Args { + #[clap(default_value = "8033", long, env)] + pub http_port: u16, + #[clap(default_value = "8034", long, env)] + pub health_http_port: u16, + #[clap( + default_value = "config/config.yaml", + long, + env = "ORCHESTRATOR_CONFIG" + )] + pub config_path: PathBuf, + #[clap(long, env)] + pub tls_cert_path: Option, + #[clap(long, env)] + pub tls_key_path: Option, + #[clap(long, env)] + pub tls_client_ca_cert_path: Option, + #[clap(default_value = "false", long, env)] + pub start_up_health_check: bool, + #[clap(long, env, value_delimiter = ',')] + pub otlp_export: Vec, + #[clap(default_value_t = LogFormat::default(), long, env)] + pub log_format: LogFormat, + #[clap(default_value_t = false, long, short, env)] + pub quiet: bool, + #[clap(default_value = "fms_guardrails_orchestr8", long, env)] + pub otlp_service_name: String, + #[clap(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")] + pub otlp_endpoint: Option, + #[clap(long, env = "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")] + pub otlp_traces_endpoint: Option, + #[clap(long, env = "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT")] + pub otlp_metrics_endpoint: Option, + #[clap( + default_value_t = OtlpProtocol::Grpc, + long, + env = "OTEL_EXPORTER_OTLP_PROTOCOL" + )] + pub otlp_protocol: OtlpProtocol, + #[clap(long, env = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL")] + pub otlp_traces_protocol: Option, + #[clap(long, env = "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL")] + pub otlp_metrics_protocol: Option, + // TODO: Add timeout and header OTLP variables +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum OtlpExport { + Traces, + Metrics, +} + +impl Display for OtlpExport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OtlpExport::Traces => write!(f, "traces"), + OtlpExport::Metrics => write!(f, "metrics"), + } + } +} + +impl From for OtlpExport { + fn from(s: String) -> Self { + match s.to_lowercase().as_str() { + "traces" => OtlpExport::Traces, + "metrics" => OtlpExport::Metrics, + _ => panic!( + "Invalid OTLP export type {}, orchestrator only supports exporting traces and metrics via OTLP", + s + ), + } + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub enum OtlpProtocol { + #[default] + Grpc, + Http, +} + +impl Display for OtlpProtocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OtlpProtocol::Grpc => write!(f, "grpc"), + OtlpProtocol::Http => write!(f, "http"), + } + } +} + +impl From for OtlpProtocol { + fn from(s: String) -> Self { + match s.to_lowercase().as_str() { + "grpc" => OtlpProtocol::Grpc, + "http" => OtlpProtocol::Http, + _ => { + error!( + "Invalid OTLP protocol {}, defaulting to {}", + s, + OtlpProtocol::default() + ); + OtlpProtocol::default() + } + } + } +} + +impl OtlpProtocol { + pub fn default_endpoint(&self) -> &str { + match self { + OtlpProtocol::Grpc => "http://localhost:4317", + OtlpProtocol::Http => "http://localhost:4318", + } + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq)] +pub enum LogFormat { + #[default] + Full, + Compact, + Pretty, + JSON, +} + +impl Display for LogFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LogFormat::Full => write!(f, "full"), + LogFormat::Compact => write!(f, "compact"), + LogFormat::Pretty => write!(f, "pretty"), + LogFormat::JSON => write!(f, "json"), + } + } +} + +impl From for LogFormat { + fn from(s: String) -> Self { + match s.to_lowercase().as_str() { + "full" => LogFormat::Full, + "compact" => LogFormat::Compact, + "pretty" => LogFormat::Pretty, + "json" => LogFormat::JSON, + _ => { + warn!( + "Invalid trace format {}, defaulting to {}", + s, + LogFormat::default() + ); + LogFormat::default() + } + } + } +} + +#[derive(Debug, Clone)] +pub struct TracingConfig { + pub service_name: String, + pub traces: Option<(OtlpProtocol, String)>, + pub metrics: Option<(OtlpProtocol, String)>, + pub log_format: LogFormat, + pub quiet: bool, +} + +impl From for TracingConfig { + fn from(args: Args) -> Self { + let otlp_protocol = args.otlp_protocol; + let otlp_endpoint = args + .otlp_endpoint + .unwrap_or(otlp_protocol.default_endpoint().to_string()); + let otlp_traces_endpoint = args.otlp_traces_endpoint.unwrap_or(otlp_endpoint.clone()); + let otlp_metrics_endpoint = args.otlp_metrics_endpoint.unwrap_or(otlp_endpoint.clone()); + let otlp_traces_protocol = args.otlp_traces_protocol.unwrap_or(otlp_protocol); + let otlp_metrics_protocol = args.otlp_metrics_protocol.unwrap_or(otlp_protocol); + + TracingConfig { + service_name: args.otlp_service_name, + traces: match args.otlp_export.contains(&OtlpExport::Traces) { + true => Some((otlp_traces_protocol, otlp_traces_endpoint)), + false => None, + }, + metrics: match args.otlp_export.contains(&OtlpExport::Metrics) { + true => Some((otlp_metrics_protocol, otlp_metrics_endpoint)), + false => None, + }, + log_format: args.log_format, + quiet: args.quiet, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 4ba228cb..efc18bc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ #![allow(clippy::iter_kv_map, clippy::enum_variant_names, async_fn_in_trait)] +pub mod args; mod clients; pub mod config; pub mod health; @@ -24,3 +25,4 @@ mod models; pub mod orchestrator; mod pb; pub mod server; +pub mod tracing_utils; diff --git a/src/main.rs b/src/main.rs index 21a5c5c0..0cc174c2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,39 +15,12 @@ */ -use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - path::PathBuf, -}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use clap::Parser; -use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator, server}; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; - -#[derive(Parser, Debug)] -#[clap(author, version, about, long_about = None)] -struct Args { - #[clap(default_value = "8033", long, env)] - http_port: u16, - #[clap(default_value = "8034", long, env)] - health_http_port: u16, - #[clap(long, env)] - json_output: bool, - #[clap( - default_value = "config/config.yaml", - long, - env = "ORCHESTRATOR_CONFIG" - )] - config_path: PathBuf, - #[clap(long, env)] - tls_cert_path: Option, - #[clap(long, env)] - tls_key_path: Option, - #[clap(long, env)] - tls_client_ca_cert_path: Option, - #[clap(default_value = "false", long, env)] - start_up_health_check: bool, -} +use fms_guardrails_orchestr8::{ + args::Args, config::OrchestratorConfig, orchestrator::Orchestrator, server, tracing_utils, +}; fn main() -> Result<(), anyhow::Error> { rustls::crypto::aws_lc_rs::default_provider() @@ -62,14 +35,6 @@ fn main() -> Result<(), anyhow::Error> { panic!("tls: cannot provide client ca cert without keypair") } - let filter = EnvFilter::try_from_default_env() - .unwrap_or(EnvFilter::new("INFO")) - .add_directive("ginepro=info".parse().unwrap()); - tracing_subscriber::registry() - .with(filter) - .with(tracing_subscriber::fmt::layer()) - .init(); - let http_addr: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.http_port); let health_http_addr: SocketAddr = @@ -81,6 +46,7 @@ fn main() -> Result<(), anyhow::Error> { .build() .unwrap() .block_on(async { + let trace_shutdown = tracing_utils::init_tracing(args.clone().into())?; let config = OrchestratorConfig::load(args.config_path).await?; let orchestrator = Orchestrator::new(config, args.start_up_health_check).await?; @@ -93,6 +59,6 @@ fn main() -> Result<(), anyhow::Error> { orchestrator, ) .await?; - Ok(()) + Ok(trace_shutdown()?) }) } diff --git a/src/server.rs b/src/server.rs index 0ad70a3e..ff83c73d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -43,8 +43,9 @@ use hyper_util::rt::{TokioExecutor, TokioIo}; use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}; use tokio::{net::TcpListener, signal}; use tokio_rustls::TlsAcceptor; +use tower_http::trace::TraceLayer; use tower_service::Service; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, info, instrument, warn}; use uuid::Uuid; use webpki::types::{CertificateDer, PrivateKeyDer}; @@ -55,6 +56,7 @@ use crate::{ GenerationWithDetectionTask, Orchestrator, StreamingClassificationWithGenTask, TextContentDetectionTask, }, + tracing_utils, }; const API_PREFIX: &str = r#"/api/v1/task"#; @@ -177,7 +179,14 @@ pub async fn run( &format!("{}/detection/generated", TEXT_API_PREFIX), post(detect_generated), ) - .with_state(shared_state); + .with_state(shared_state) + .layer( + TraceLayer::new_for_http() + .make_span_with(tracing_utils::incoming_request_span) + .on_request(tracing_utils::on_incoming_request) + .on_response(tracing_utils::on_outgoing_response) + .on_eos(tracing_utils::on_outgoing_eos), + ); // (2c) Generate main guardrails server handle based on whether TLS is needed let listener: TcpListener = TcpListener::bind(&http_addr) @@ -299,6 +308,7 @@ async fn info( Ok(Json(InfoResponse { services })) } +#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn classification_with_gen( State(state): State>, headers: HeaderMap, @@ -318,6 +328,7 @@ async fn classification_with_gen( } } +#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn generation_with_detection( State(state): State>, headers: HeaderMap, @@ -340,6 +351,7 @@ async fn generation_with_detection( } } +#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn stream_classification_with_gen( State(state): State>, headers: HeaderMap, @@ -382,6 +394,7 @@ async fn stream_classification_with_gen( Sse::new(event_stream).keep_alive(KeepAlive::default()) } +#[instrument(skip_all)] async fn detection_content( State(state): State>, headers: HeaderMap, @@ -397,6 +410,7 @@ async fn detection_content( } } +#[instrument(skip_all)] async fn detect_context_documents( State(state): State>, headers: HeaderMap, @@ -416,6 +430,7 @@ async fn detect_context_documents( } } +#[instrument(skip_all)] async fn detect_generated( State(state): State>, headers: HeaderMap, diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs new file mode 100644 index 00000000..b2a18ab1 --- /dev/null +++ b/src/tracing_utils.rs @@ -0,0 +1,322 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +use std::time::Duration; + +use axum::extract::Request; +use axum::http::HeaderMap; +use axum::response::Response; +use opentelemetry::{ + global, + metrics::MetricsError, + trace::{TraceContextExt, TraceError, TracerProvider}, + KeyValue, +}; +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::{ + metrics::{ + reader::{DefaultAggregationSelector, DefaultTemporalitySelector}, + SdkMeterProvider, + }, + runtime, + trace::{Config, Sampler}, + Resource, +}; +use tracing::{error, info, info_span, Span}; +use tracing_opentelemetry::{MetricsLayer, OpenTelemetrySpanExt}; +use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer}; + +use crate::args::{LogFormat, OtlpProtocol, TracingConfig}; + +#[derive(Debug, thiserror::Error)] +pub enum TracingError { + #[error("Error from tracing provider: {0}")] + TraceError(#[from] TraceError), + #[error("Error from metrics provider: {0}")] + MetricsError(#[from] MetricsError), +} + +/// Initializes an OpenTelemetry tracer provider with an OTLP export pipeline based on the +/// provided config. +fn init_tracer_provider( + otlp_export_config: TracingConfig, +) -> Result, TracingError> { + if let Some((protocol, endpoint)) = otlp_export_config.traces { + Ok(Some( + match protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::new_pipeline().tracing().with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(endpoint) + .with_timeout(Duration::from_secs(3)), + ), + OtlpProtocol::Http => opentelemetry_otlp::new_pipeline().tracing().with_exporter( + opentelemetry_otlp::new_exporter() + .http() + .with_http_client(reqwest::Client::new()) + .with_endpoint(endpoint) + .with_timeout(Duration::from_secs(3)), + ), + } + .with_trace_config( + Config::default() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + otlp_export_config.service_name, + )])) + .with_sampler(Sampler::AlwaysOn), + ) + .install_batch(runtime::Tokio)?, + )) + } else { + Ok(None) + } +} + +/// Initializes an OpenTelemetry meter provider with an OTLP export pipeline based on the +/// provided config. +fn init_meter_provider( + otlp_export_config: TracingConfig, +) -> Result, TracingError> { + if let Some((protocol, endpoint)) = otlp_export_config.metrics { + Ok(Some( + match protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::new_pipeline() + .metrics(runtime::Tokio) + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(endpoint), + ), + OtlpProtocol::Http => opentelemetry_otlp::new_pipeline() + .metrics(runtime::Tokio) + .with_exporter( + opentelemetry_otlp::new_exporter() + .http() + .with_http_client(reqwest::Client::new()) + .with_endpoint(endpoint), + ), + } + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + otlp_export_config.service_name, + )])) + .with_timeout(Duration::from_secs(10)) + .with_period(Duration::from_secs(3)) + .with_aggregation_selector(DefaultAggregationSelector::new()) + .with_temporality_selector(DefaultTemporalitySelector::new()) + .build()?, + )) + } else { + Ok(None) + } +} + +/// Initializes tracing for the orchestrator using the OpenTelemetry API/SDK and the `tracing` +/// crate. What telemetry is exported and to where is determined based on the provided config +pub fn init_tracing( + tracing_config: TracingConfig, +) -> Result Result<(), TracingError>, TracingError> { + let mut layers = Vec::new(); + + // TODO: Find a better way to only propagate errors from other crates + let filter = EnvFilter::try_from_default_env() + .unwrap_or(EnvFilter::new("INFO")) + .add_directive("ginepro=info".parse().unwrap()) + .add_directive("hyper=error".parse().unwrap()) + .add_directive("h2=error".parse().unwrap()) + .add_directive("trust_dns_resolver=error".parse().unwrap()) + .add_directive("trust_dns_proto=error".parse().unwrap()) + .add_directive("tower=error".parse().unwrap()) + .add_directive("tonic=error".parse().unwrap()) + .add_directive("reqwest=error".parse().unwrap()); + + // Set up tracing layer with OTLP exporter + let trace_provider = init_tracer_provider(tracing_config.clone())?; + if let Some(tracer_provider) = trace_provider.clone() { + global::set_tracer_provider(tracer_provider.clone()); + layers.push( + tracing_opentelemetry::layer() + .with_tracer(tracer_provider.tracer(tracing_config.service_name.clone())) + .boxed(), + ); + } + + // Set up metrics layer with OTLP exporter + let meter_provider = init_meter_provider(tracing_config.clone())?; + if let Some(meter_provider) = meter_provider.clone() { + global::set_meter_provider(meter_provider.clone()); + layers.push(MetricsLayer::new(meter_provider).boxed()); + } + + // Set up formatted layer for logging to stdout + // Because we use the `tracing` crate for logging, all logs are traces and will be exported + // to OTLP if `--otlp-export=traces` is set. + if !tracing_config.quiet { + match tracing_config.log_format { + LogFormat::Full => layers.push(tracing_subscriber::fmt::layer().boxed()), + LogFormat::Compact => layers.push(tracing_subscriber::fmt::layer().compact().boxed()), + LogFormat::Pretty => layers.push(tracing_subscriber::fmt::layer().pretty().boxed()), + LogFormat::JSON => layers.push( + tracing_subscriber::fmt::layer() + .json() + .flatten_event(true) + .boxed(), + ), + } + } + + let subscriber = tracing_subscriber::registry().with(filter).with(layers); + tracing::subscriber::set_global_default(subscriber).unwrap(); + + if let Some(traces) = tracing_config.traces { + info!( + "OTLP tracing enabled: Exporting {} to {}", + traces.0, traces.1 + ); + } else { + info!("OTLP traces export disabled") + } + + if let Some(metrics) = tracing_config.metrics { + info!( + "OTLP metrics enabled: Exporting {} to {}", + metrics.0, metrics.1 + ); + } else { + info!("OTLP metrics export disabled") + } + + if !tracing_config.quiet { + info!( + "Stdout logging enabled with format {}", + tracing_config.log_format + ); + } else { + info!("Stdout logging disabled"); // This will only be visible in traces + } + + Ok(move || { + global::shutdown_tracer_provider(); + if let Some(meter_provider) = meter_provider { + meter_provider + .shutdown() + .map_err(TracingError::MetricsError)?; + } + Ok(()) + }) +} + +pub fn incoming_request_span(request: &Request) -> Span { + info_span!( + "incoming_orchestrator_http_request", + request_method = request.method().to_string(), + request_path = request.uri().path().to_string(), + response_status_code = tracing::field::Empty, + request_duration_ms = tracing::field::Empty, + stream_response = tracing::field::Empty, + stream_response_event_count = tracing::field::Empty, + stream_response_error_count = tracing::field::Empty, + stream_response_duration_ms = tracing::field::Empty, + ) +} + +pub fn on_incoming_request(request: &Request, span: &Span) { + let _guard = span.enter(); + info!( + "incoming request to {} {} with trace_id {}", + request.method(), + request.uri().path(), + span.context().span().span_context().trace_id().to_string() + ); + info!( + monotonic_counter.incoming_request_count = 1, + request_method = request.method().as_str(), + request_path = request.uri().path() + ); +} + +pub fn on_outgoing_response(response: &Response, latency: Duration, span: &Span) { + let _guard = span.enter(); + span.record("response_status_code", response.status().as_u16()); + span.record("request_duration_ms", latency.as_millis()); + + info!( + "response {} for request with with trace_id {} generated in {} ms", + &response.status(), + span.context().span().span_context().trace_id().to_string(), + latency.as_millis() + ); + + // On every response + info!( + monotonic_counter.handled_request_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + info!( + histogram.service_request_duration = latency.as_millis(), + response_status = response.status().as_u16() + ); + + if response.status().is_server_error() { + // On every server error (HTTP 5xx) response + info!( + monotonic_counter.server_error_response_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + } else if response.status().is_client_error() { + // On every client error (HTTP 4xx) response + info!( + monotonic_counter.client_error_response_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + } else if response.status().is_success() { + // On every successful (HTTP 2xx) response + info!( + monotonic_counter.success_response_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + } else { + error!( + "unexpected response status code: {}", + response.status().as_u16() + ); + } +} + +pub fn on_outgoing_eos(trailers: Option<&HeaderMap>, stream_duration: Duration, span: &Span) { + let _guard = span.enter(); + + span.record("stream_response", true); + span.record("stream_response_duration_ms", stream_duration.as_millis()); + + info!( + "stream response for request with trace_id {} closed after {} ms with trailers: {:?}", + span.context().span().span_context().trace_id().to_string(), + stream_duration.as_millis(), + trailers + ); + info!( + monotonic_counter.service_stream_response_count = 1, + stream_duration = stream_duration.as_millis() + ); + info!(monotonic_histogram.service_stream_response_duration = stream_duration.as_millis()); +} From 462d7ff4eb773d0b4cd126ffdef29e8b39cbbf4c Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Mon, 14 Oct 2024 16:01:45 -0500 Subject: [PATCH 18/50] :bug: Fix whole_doc_chunker missing from client list Signed-off-by: gkumbhat --- src/orchestrator.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 82939d18..3c35544d 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -29,15 +29,15 @@ use uuid::Uuid; use crate::{ clients::{ + chunker::{ChunkerClient, DEFAULT_MODEL_ID as CHUNKER_DEFAULT_MODEL_ID}, detector::{ text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, }, openai::OpenAiClient, - ChunkerClient, ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, - TgisClient, + ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, }, - config::{DetectorType, GenerationProvider, OrchestratorConfig}, + config::{DetectorType, GenerationProvider, OrchestratorConfig, ServiceConfig}, health::HealthCheckCache, models::{ ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams, @@ -199,6 +199,11 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { clients.insert(chunker_id.to_string(), chunker_client); } } + // Insert default chunker + clients.insert( + CHUNKER_DEFAULT_MODEL_ID.to_string(), + ChunkerClient::new(&ServiceConfig::default()).await, + ); // Create detector clients for (detector_id, detector) in &config.detectors { From 96f0933f12e514337d4e3b69b95e9577bf988553 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Mon, 14 Oct 2024 18:59:50 -0500 Subject: [PATCH 19/50] :recycle: Refactor to implement whole_doc_chunker separately Signed-off-by: gkumbhat --- src/clients/chunker.rs | 48 ++++++++++------------------------- src/orchestrator.rs | 9 ++----- src/orchestrator/streaming.rs | 30 +++++++++++++++++----- src/orchestrator/unary.rs | 26 ++++++++++++------- 4 files changed, 57 insertions(+), 56 deletions(-) diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 8f21f821..c6cc568d 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -20,10 +20,7 @@ use std::pin::Pin; use async_trait::async_trait; use futures::{Future, Stream, StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; use tonic::{Code, Request, Response, Status, Streaming}; -use tracing::info; use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ @@ -70,11 +67,6 @@ impl ChunkerClient { model_id: &str, request: ChunkerTokenizationTaskRequest, ) -> Result { - // Handle "default" separately first - if model_id == DEFAULT_MODEL_ID { - info!("Using default whole doc chunker"); - return Ok(tokenize_whole_doc(request)); - } let mut client = self.client.clone(); let request = request_with_model_id(request, model_id); Ok(client @@ -88,30 +80,18 @@ impl ChunkerClient { model_id: &str, request_stream: BoxStream, ) -> Result>, Error> { - let response_stream = if model_id == DEFAULT_MODEL_ID { - info!("Using default whole doc chunker"); - let (response_tx, response_rx) = mpsc::channel(1); - // Spawn task to collect input stream - tokio::spawn(async move { - // NOTE: this will not resolve until the input stream is closed - let response = tokenize_whole_doc_stream(request_stream).await; - let _ = response_tx.send(response).await; - }); - ReceiverStream::new(response_rx).boxed() - } else { - let mut client = self.client.clone(); - let request = request_with_model_id(request_stream, model_id); - // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors. - // https://github.com/rust-lang/rust/issues/110338 - let response_stream_fut: Pin< - Box + Send>, - > = Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request)); - response_stream_fut - .await? - .into_inner() - .map_err(Into::into) - .boxed() - }; + let mut client = self.client.clone(); + let request = request_with_model_id(request_stream, model_id); + // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors. + // https://github.com/rust-lang/rust/issues/110338 + let response_stream_fut: Pin + Send>> = + Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request)); + let response_stream = response_stream_fut + .await? + .into_inner() + .map_err(Into::into) + .boxed(); + Ok(response_stream) } } @@ -157,7 +137,7 @@ fn request_with_model_id(request: T, model_id: &str) -> Request { } /// Unary tokenization result of the entire doc -fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults { +pub fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults { let codepoint_count = request.text.chars().count() as i64; TokenizationResults { results: vec![Token { @@ -170,7 +150,7 @@ fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationRe } /// Streaming tokenization result for the entire doc stream -async fn tokenize_whole_doc_stream( +pub async fn tokenize_whole_doc_stream( request: impl Stream, ) -> Result { let (text, index_vec): (String, Vec) = request diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 3c35544d..967ca427 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -29,7 +29,7 @@ use uuid::Uuid; use crate::{ clients::{ - chunker::{ChunkerClient, DEFAULT_MODEL_ID as CHUNKER_DEFAULT_MODEL_ID}, + chunker::ChunkerClient, detector::{ text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, @@ -37,7 +37,7 @@ use crate::{ openai::OpenAiClient, ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, }, - config::{DetectorType, GenerationProvider, OrchestratorConfig, ServiceConfig}, + config::{DetectorType, GenerationProvider, OrchestratorConfig}, health::HealthCheckCache, models::{ ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams, @@ -199,11 +199,6 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { clients.insert(chunker_id.to_string(), chunker_client); } } - // Insert default chunker - clients.insert( - CHUNKER_DEFAULT_MODEL_ID.to_string(), - ChunkerClient::new(&ServiceConfig::default()).await, - ); // Create detector clients for (detector_id, detector) in &config.detectors { diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index b8697df1..d671407f 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -29,8 +29,11 @@ use tracing::{debug, error, info, instrument}; use super::{get_chunker_ids, Context, Error, Orchestrator, StreamingClassificationWithGenTask}; use crate::{ clients::{ - detector::ContentAnalysisRequest, ChunkerClient, GenerationClient, - TextContentsDetectorClient, + chunker::{ + tokenize_whole_doc_stream, ChunkerClient, DEFAULT_MODEL_ID as CHUNKER_DEFAULT_MODEL_ID, + }, + detector::ContentAnalysisRequest, + GenerationClient, TextContentsDetectorClient, }, models::{ ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsTextGenerationParameters, @@ -455,10 +458,25 @@ async fn chunk_broadcast_task( .boxed(); debug!(%chunker_id, "creating chunker output stream"); let id = chunker_id.clone(); // workaround for StreamExt::map_err - let client = ctx.clients.get_as::(&chunker_id).unwrap(); - let mut output_stream = client - .bidi_streaming_tokenization_task_predict(&chunker_id, input_stream) - .await + + let response_stream = if chunker_id == CHUNKER_DEFAULT_MODEL_ID { + info!("Using default whole doc chunker"); + let (response_tx, response_rx) = mpsc::channel(1); + // Spawn task to collect input stream + tokio::spawn(async move { + // NOTE: this will not resolve until the input stream is closed + let response = tokenize_whole_doc_stream(input_stream).await; + let _ = response_tx.send(response).await; + }); + Ok(ReceiverStream::new(response_rx).boxed()) + } else { + let client = ctx.clients.get_as::(&chunker_id).unwrap(); + client + .bidi_streaming_tokenization_task_predict(&chunker_id, input_stream) + .await + }; + + let mut output_stream = response_stream .map_err(|error| Error::ChunkerRequestFailed { id: chunker_id.clone(), error, diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index a16b0e93..c387cf52 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -31,12 +31,15 @@ use super::{ }; use crate::{ clients::{ + chunker::{ + tokenize_whole_doc, ChunkerClient, DEFAULT_MODEL_ID as CHUNKER_DEFAULT_MODEL_ID, + }, detector::{ ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest, ContextType, GenerationDetectionRequest, TextContentsDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, }, - ChunkerClient, GenerationClient, + GenerationClient, }, models::{ ClassifiedGeneratedTextResult, ContextDocsResult, DetectionOnGenerationResult, @@ -763,14 +766,19 @@ pub async fn chunk( ) -> Result, Error> { let request = chunkers::ChunkerTokenizationTaskRequest { text }; debug!(%chunker_id, ?request, "sending chunker request"); - let client = ctx.clients.get_as::(&chunker_id).unwrap(); - let response = client - .tokenization_task_predict(&chunker_id, request) - .await - .map_err(|error| Error::ChunkerRequestFailed { - id: chunker_id.clone(), - error, - })?; + let response = if chunker_id == CHUNKER_DEFAULT_MODEL_ID { + tokenize_whole_doc(request) + } else { + let client = ctx.clients.get_as::(&chunker_id).unwrap(); + client + .tokenization_task_predict(&chunker_id, request) + .await + .map_err(|error| Error::ChunkerRequestFailed { + id: chunker_id.clone(), + error, + })? + }; + debug!(%chunker_id, ?response, "received chunker response"); Ok(response .results From 65babf9c20d105c18d99c4ed73f02bab96e7ac44 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Tue, 15 Oct 2024 09:09:45 -0500 Subject: [PATCH 20/50] :art: Fix formatting Signed-off-by: gkumbhat --- src/clients.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/clients.rs b/src/clients.rs index d85c6b6b..392eac05 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -41,7 +41,6 @@ pub mod http; pub use http::HttpClient; pub mod chunker; -pub use chunker::ChunkerClient; pub mod detector; pub use detector::TextContentsDetectorClient; From 0bf8e20de76bb5a7fbb995b6a4dd080ebd542dcd Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Tue, 15 Oct 2024 13:24:37 -0500 Subject: [PATCH 21/50] :truck: Rename default model id for chunker to default chunker id Signed-off-by: gkumbhat --- src/clients/chunker.rs | 2 +- src/config.rs | 8 ++++---- src/orchestrator/streaming.rs | 6 ++---- src/orchestrator/unary.rs | 6 ++---- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index c6cc568d..65795653 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -39,7 +39,7 @@ use crate::{ const DEFAULT_PORT: u16 = 8085; const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; /// Default chunker that returns span for entire text -pub const DEFAULT_MODEL_ID: &str = "whole_doc_chunker"; +pub const DEFAULT_CHUNKER_ID: &str = "whole_doc_chunker"; type StreamingTokenizationResult = Result>, Status>; diff --git a/src/config.rs b/src/config.rs index 89c4d559..14053db1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,7 @@ use std::{ use serde::Deserialize; use tracing::{debug, error, info, warn}; -use crate::clients::{chunker::DEFAULT_MODEL_ID, is_valid_hostname}; +use crate::clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname}; // Placeholder to add default allowed headers const DEFAULT_ALLOWED_HEADERS: &[&str] = &[]; @@ -308,7 +308,7 @@ impl OrchestratorConfig { ))); } // Chunker is valid - let valid_chunker = detector.chunker_id == DEFAULT_MODEL_ID + let valid_chunker = detector.chunker_id == DEFAULT_CHUNKER_ID || self .chunkers .as_ref() @@ -671,7 +671,7 @@ detectors: port: 9000 tls: detector chunker_id: sentence-fr - default_threshold: 0.5 + default_threshold: 0.5 "#; let config: OrchestratorConfig = serde_yml::from_str(s).unwrap(); assert!(config.passthrough_headers.is_empty()); @@ -699,7 +699,7 @@ detectors: port: 9000 tls: detector chunker_id: sentence-fr - default_threshold: 0.5 + default_threshold: 0.5 passthrough_headers: - test-header diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index d671407f..e6a735eb 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -29,9 +29,7 @@ use tracing::{debug, error, info, instrument}; use super::{get_chunker_ids, Context, Error, Orchestrator, StreamingClassificationWithGenTask}; use crate::{ clients::{ - chunker::{ - tokenize_whole_doc_stream, ChunkerClient, DEFAULT_MODEL_ID as CHUNKER_DEFAULT_MODEL_ID, - }, + chunker::{tokenize_whole_doc_stream, ChunkerClient, DEFAULT_CHUNKER_ID}, detector::ContentAnalysisRequest, GenerationClient, TextContentsDetectorClient, }, @@ -459,7 +457,7 @@ async fn chunk_broadcast_task( debug!(%chunker_id, "creating chunker output stream"); let id = chunker_id.clone(); // workaround for StreamExt::map_err - let response_stream = if chunker_id == CHUNKER_DEFAULT_MODEL_ID { + let response_stream = if chunker_id == DEFAULT_CHUNKER_ID { info!("Using default whole doc chunker"); let (response_tx, response_rx) = mpsc::channel(1); // Spawn task to collect input stream diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index c387cf52..b1746d61 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -31,9 +31,7 @@ use super::{ }; use crate::{ clients::{ - chunker::{ - tokenize_whole_doc, ChunkerClient, DEFAULT_MODEL_ID as CHUNKER_DEFAULT_MODEL_ID, - }, + chunker::{tokenize_whole_doc, ChunkerClient, DEFAULT_CHUNKER_ID}, detector::{ ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest, ContextType, GenerationDetectionRequest, TextContentsDetectorClient, @@ -766,7 +764,7 @@ pub async fn chunk( ) -> Result, Error> { let request = chunkers::ChunkerTokenizationTaskRequest { text }; debug!(%chunker_id, ?request, "sending chunker request"); - let response = if chunker_id == CHUNKER_DEFAULT_MODEL_ID { + let response = if chunker_id == DEFAULT_CHUNKER_ID { tokenize_whole_doc(request) } else { let client = ctx.clients.get_as::(&chunker_id).unwrap(); From a9e62c24216b9acc21ab2e2e2b6271e61ef55a7d Mon Sep 17 00:00:00 2001 From: pscoro <78318122+pscoro@users.noreply.github.com> Date: Thu, 17 Oct 2024 14:09:23 -0400 Subject: [PATCH 22/50] missing licence comments on recent new files (#232) Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients/detector/text_chat.rs | 17 +++++++++++++++++ src/clients/detector/text_contents.rs | 17 +++++++++++++++++ src/clients/detector/text_context_doc.rs | 17 +++++++++++++++++ src/clients/detector/text_generation.rs | 17 +++++++++++++++++ src/clients/errors.rs | 17 +++++++++++++++++ src/clients/http.rs | 17 +++++++++++++++++ src/clients/openai.rs | 17 +++++++++++++++++ src/clients/tgis.rs | 1 + 8 files changed, 120 insertions(+) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index 39214c43..6aecff3b 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use async_trait::async_trait; use super::DEFAULT_PORT; diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index a146e6a4..b8959c0f 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index 6ef4661d..ae9973e8 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index b1190daa..98236dfe 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::Serialize; diff --git a/src/clients/errors.rs b/src/clients/errors.rs index e8f638a8..e630e21a 100644 --- a/src/clients/errors.rs +++ b/src/clients/errors.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use std::error::Error as _; use hyper::StatusCode; diff --git a/src/clients/http.rs b/src/clients/http.rs index fdd6aef6..862db811 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use hyper::StatusCode; use reqwest::Response; use tracing::error; diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 584ba436..66eea424 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use std::collections::HashMap; use async_trait::async_trait; diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index 1195d642..09d01395 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -14,6 +14,7 @@ limitations under the License. */ + use async_trait::async_trait; use axum::http::HeaderMap; use futures::{StreamExt, TryStreamExt}; From 7587916d1de7514c4c7dd7a24a437d4f411d7a65 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 14 Oct 2024 12:58:06 -0300 Subject: [PATCH 23/50] Implement chat detection endpoint Signed-off-by: Mateus Devino --- src/clients/detector/text_chat.rs | 53 ++++++++++++-- src/models.rs | 40 ++++++++++- src/orchestrator.rs | 34 ++++++++- src/orchestrator/unary.rs | 115 +++++++++++++++++++++++++++--- src/server.rs | 25 ++++++- 5 files changed, 247 insertions(+), 20 deletions(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index 39214c43..fcc5a35b 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -1,10 +1,13 @@ use async_trait::async_trait; +use hyper::{HeaderMap, StatusCode}; +use serde::Serialize; -use super::DEFAULT_PORT; +use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ - clients::{create_http_client, Client, HttpClient}, + clients::{create_http_client, openai::Message, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, + models::DetectionResult, }; #[cfg_attr(test, faux::create)] @@ -29,9 +32,34 @@ impl TextChatDetectorClient { } } - pub async fn text_chat(&self) { - let _url = self.client.base_url().join("/api/v1/text/chat").unwrap(); - todo!() + pub async fn text_chat( + &self, + model_id: &str, + request: ChatDetectionRequest, + headers: HeaderMap, + ) -> Result, Error> { + let url = self.client.base_url().join("/api/v1/text/chat").unwrap(); + let response = self + .client + .post(url) + .headers(headers) + .header(DETECTOR_ID_HEADER_NAME, model_id) + .json(&request) + .send() + .await?; + if response.status() == StatusCode::OK { + Ok(response.json().await?) + } else { + let code = response.status().as_u16(); + let error = response + .json::() + .await + .unwrap_or(DetectorError { + code, + message: "".into(), + }); + Err(error.into()) + } } } @@ -50,3 +78,18 @@ impl Client for TextChatDetectorClient { } } } + +/// A struct representing a request to a detector compatible with the +/// /api/v1/text/chat endpoint. +// #[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Serialize)] +pub struct ChatDetectionRequest { + /// Chat messages to run detection on + pub messages: Vec, +} + +impl ChatDetectionRequest { + pub fn new(messages: Vec) -> Self { + Self { messages } + } +} diff --git a/src/models.rs b/src/models.rs index ca819628..e8771051 100644 --- a/src/models.rs +++ b/src/models.rs @@ -22,7 +22,10 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use crate::{ - clients::detector::{ContentAnalysisResponse, ContextType}, + clients::{ + self, + detector::{ContentAnalysisResponse, ContextType}, + }, health::HealthCheckCache, pb, }; @@ -939,6 +942,41 @@ pub struct ContextDocsResult { pub detections: Vec, } +/// The request format expected in the /api/v2/text/detect/generated endpoint. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChatDetectionHttpRequest { + /// The map of detectors to be used, along with their respective parameters, e.g. thresholds. + pub detectors: HashMap, + + // The list of messages to run detections on. + pub messages: Vec, +} + +impl ChatDetectionHttpRequest { + /// Upfront validation of user request + pub fn validate(&self) -> Result<(), ValidationError> { + // Validate required parameters + if self.detectors.is_empty() { + return Err(ValidationError::Required("detectors".into())); + } + if self.messages.is_empty() { + return Err(ValidationError::Required("messages".into())); + } + + // Validate detector params + validate_detector_params(&self.detectors)?; + + Ok(()) + } +} + +/// The response format of the /api/v2/text/detection/chat endpoint +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ChatDetectionResult { + /// Detection results + pub detections: Vec, +} + /// The request format expected in the /api/v2/text/detect/generated endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct DetectionOnGeneratedHttpRequest { diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 967ca427..26a87976 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -30,6 +30,7 @@ use uuid::Uuid; use crate::{ clients::{ chunker::ChunkerClient, + self, detector::{ text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, @@ -40,9 +41,9 @@ use crate::{ config::{DetectorType, GenerationProvider, OrchestratorConfig}, health::HealthCheckCache, models::{ - ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams, - GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest, - GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest, + ChatDetectionHttpRequest, ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, + DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig, + GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest, }, }; @@ -382,6 +383,33 @@ impl ContextDocsDetectionTask { } } +/// Task for the /api/v2/text/detection/chat endpoint +#[derive(Debug)] +pub struct ChatDetectionTask { + /// Request unique identifier + pub request_id: Uuid, + + /// Detectors configuration + pub detectors: HashMap, + + // Messages to run detection on + pub messages: Vec, + + // Headermap + pub headers: HeaderMap, +} + +impl ChatDetectionTask { + pub fn new(request_id: Uuid, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self { + Self { + request_id, + detectors: request.detectors, + messages: request.messages, + headers, + } + } +} + /// Task for the /api/v2/text/detection/generated endpoint #[derive(Debug)] pub struct DetectionOnGenerationTask { diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index b1746d61..ab114a3a 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -25,7 +25,7 @@ use futures::{ use tracing::{debug, error, info, instrument}; use super::{ - apply_masks, get_chunker_ids, Chunk, ClassificationWithGenTask, Context, + apply_masks, get_chunker_ids, ChatDetectionTask, Chunk, ClassificationWithGenTask, Context, ContextDocsDetectionTask, DetectionOnGenerationTask, Error, GenerationWithDetectionTask, Orchestrator, TextContentDetectionTask, }; @@ -33,17 +33,20 @@ use crate::{ clients::{ chunker::{tokenize_whole_doc, ChunkerClient, DEFAULT_CHUNKER_ID}, detector::{ - ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest, - ContextType, GenerationDetectionRequest, TextContentsDetectorClient, - TextContextDocDetectorClient, TextGenerationDetectorClient, + ChatDetectionRequest, ContentAnalysisRequest, ContentAnalysisResponse, + ContextDocsDetectionRequest, ContextType, GenerationDetectionRequest, + TextChatDetectorClient, TextContentsDetectorClient, TextContextDocDetectorClient, + TextGenerationDetectorClient, }, + openai::Message, GenerationClient, }, models::{ - ClassifiedGeneratedTextResult, ContextDocsResult, DetectionOnGenerationResult, - DetectionResult, DetectorParams, GenerationWithDetectionResult, - GuardrailsTextGenerationParameters, InputWarning, InputWarningReason, - TextContentDetectionResult, TextGenTokenClassificationResults, TokenClassificationResult, + ChatDetectionResult, ClassifiedGeneratedTextResult, ContextDocsResult, + DetectionOnGenerationResult, DetectionResult, DetectorParams, + GenerationWithDetectionResult, GuardrailsTextGenerationParameters, InputWarning, + InputWarningReason, TextContentDetectionResult, TextGenTokenClassificationResults, + TokenClassificationResult, }, orchestrator::UNSUITABLE_INPUT_MESSAGE, pb::caikit::runtime::chunkers, @@ -447,6 +450,61 @@ impl Orchestrator { } } } + + /// Handles detections on chat messages (without performing generation) + pub async fn handle_chat_detection( + &self, + task: ChatDetectionTask, + ) -> Result { + info!( + request_id = ?task.request_id, + detectors = ?task.detectors, + "handling detection on chat content task" + ); + let ctx = self.ctx.clone(); + let headers = task.headers; + + let task_handle = tokio::spawn(async move { + // call detection + let detections = try_join_all( + task.detectors + .iter() + .map(|(detector_id, detector_params)| { + let ctx = ctx.clone(); + let detector_id = detector_id.clone(); + let detector_params = detector_params.clone(); + let messages = task.messages.clone(); + let headers = headers.clone(); + async { + detect_for_chat(ctx, detector_id, detector_params, messages, headers) + .await + } + }) + .collect::>(), + ) + .await? + .into_iter() + .flatten() + .collect::>(); + + Ok(ChatDetectionResult { detections }) + }); + match task_handle.await { + // Task completed successfully + Ok(Ok(result)) => Ok(result), + // Task failed, return error propagated from child task that failed + Ok(Err(error)) => { + error!(request_id = ?task.request_id, %error, "detection on chat content task failed"); + Err(error) + } + // Task cancelled or panicked + Err(error) => { + let error = error.into(); + error!(request_id = ?task.request_id, %error, "detection on chat content task failed"); + Err(error) + } + } + } } /// Handles input detection task. @@ -711,6 +769,47 @@ pub async fn detect_for_generation( Ok::, Error>(response) } +/// Calls a detector that implements the /api/v1/text/chat endpoint +pub async fn detect_for_chat( + ctx: Arc, + detector_id: String, + detector_params: DetectorParams, + messages: Vec, + headers: HeaderMap, +) -> Result, Error> { + let detector_id = detector_id.clone(); + let threshold = detector_params.threshold().unwrap_or( + detector_params.threshold().unwrap_or( + ctx.config + .detectors + .get(&detector_id) + .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))? + .default_threshold, + ), + ); + let request = ChatDetectionRequest::new(messages.clone()); + debug!(%detector_id, ?request, "sending chat detector request"); + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + let response = client + .text_chat(&detector_id, request, headers) + .await + .map(|results| { + results + .into_iter() + .filter(|detection| detection.score > threshold) + .collect() + }) + .map_err(|error| Error::DetectorRequestFailed { + id: detector_id.clone(), + error, + })?; + debug!(%detector_id, ?response, "received chat detector response"); + Ok::, Error>(response) +} + /// Calls a detector that implements the /api/v1/text/doc endpoint pub async fn detect_for_context( ctx: Arc, diff --git a/src/server.rs b/src/server.rs index ff83c73d..b891efef 100644 --- a/src/server.rs +++ b/src/server.rs @@ -52,9 +52,9 @@ use webpki::types::{CertificateDer, PrivateKeyDer}; use crate::{ models::{self, InfoParams, InfoResponse}, orchestrator::{ - self, ClassificationWithGenTask, ContextDocsDetectionTask, DetectionOnGenerationTask, - GenerationWithDetectionTask, Orchestrator, StreamingClassificationWithGenTask, - TextContentDetectionTask, + self, ChatDetectionTask, ClassificationWithGenTask, ContextDocsDetectionTask, + DetectionOnGenerationTask, GenerationWithDetectionTask, Orchestrator, + StreamingClassificationWithGenTask, TextContentDetectionTask, }, tracing_utils, }; @@ -171,6 +171,10 @@ pub async fn run( &format!("{}/detection/content", TEXT_API_PREFIX), post(detection_content), ) + .route( + &format!("{}/detection/chat", TEXT_API_PREFIX), + post(detect_chat), + ) .route( &format!("{}/detection/context", TEXT_API_PREFIX), post(detect_context_documents), @@ -430,6 +434,21 @@ async fn detect_context_documents( } } +#[instrument(skip_all)] +async fn detect_chat( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + let request_id = Uuid::new_v4(); + request.validate()?; + let task = ChatDetectionTask::new(request_id, request, headers); + match state.orchestrator.handle_chat_detection(task).await { + Ok(response) => Ok(Json(response).into_response()), + Err(error) => Err(error.into()), + } +} + #[instrument(skip_all)] async fn detect_generated( State(state): State>, From 61f240256ae089a4dd52889ff11274aaa4b99751 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 15 Oct 2024 17:55:27 -0300 Subject: [PATCH 24/50] Replace openai Message by String Signed-off-by: Mateus Devino --- src/clients/detector/text_chat.rs | 6 +++--- src/models.rs | 7 ++----- src/orchestrator.rs | 3 +-- src/orchestrator/unary.rs | 3 +-- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index fcc5a35b..fa51e9b7 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -4,7 +4,7 @@ use serde::Serialize; use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ - clients::{create_http_client, openai::Message, Client, Error, HttpClient}, + clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, models::DetectionResult, @@ -85,11 +85,11 @@ impl Client for TextChatDetectorClient { #[derive(Debug, Serialize)] pub struct ChatDetectionRequest { /// Chat messages to run detection on - pub messages: Vec, + pub messages: Vec, } impl ChatDetectionRequest { - pub fn new(messages: Vec) -> Self { + pub fn new(messages: Vec) -> Self { Self { messages } } } diff --git a/src/models.rs b/src/models.rs index e8771051..8e5f706d 100644 --- a/src/models.rs +++ b/src/models.rs @@ -22,10 +22,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use crate::{ - clients::{ - self, - detector::{ContentAnalysisResponse, ContextType}, - }, + clients::detector::{ContentAnalysisResponse, ContextType}, health::HealthCheckCache, pb, }; @@ -949,7 +946,7 @@ pub struct ChatDetectionHttpRequest { pub detectors: HashMap, // The list of messages to run detections on. - pub messages: Vec, + pub messages: Vec, } impl ChatDetectionHttpRequest { diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 26a87976..ed0a95ab 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -30,7 +30,6 @@ use uuid::Uuid; use crate::{ clients::{ chunker::ChunkerClient, - self, detector::{ text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, @@ -393,7 +392,7 @@ pub struct ChatDetectionTask { pub detectors: HashMap, // Messages to run detection on - pub messages: Vec, + pub messages: Vec, // Headermap pub headers: HeaderMap, diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index ab114a3a..57c68e17 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -38,7 +38,6 @@ use crate::{ TextChatDetectorClient, TextContentsDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, }, - openai::Message, GenerationClient, }, models::{ @@ -774,7 +773,7 @@ pub async fn detect_for_chat( ctx: Arc, detector_id: String, detector_params: DetectorParams, - messages: Vec, + messages: Vec, headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); From acbf7189c720d5a79e5287f9a03dcbcf50fa8cc1 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 15 Oct 2024 18:08:00 -0300 Subject: [PATCH 25/50] Add detector_params Signed-off-by: Mateus Devino --- src/clients/detector/text_chat.rs | 11 ++++++++--- src/orchestrator/unary.rs | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index fa51e9b7..2bf3954a 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -7,7 +7,7 @@ use crate::{ clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, - models::DetectionResult, + models::{DetectionResult, DetectorParams}, }; #[cfg_attr(test, faux::create)] @@ -86,10 +86,15 @@ impl Client for TextChatDetectorClient { pub struct ChatDetectionRequest { /// Chat messages to run detection on pub messages: Vec, + + pub detector_params: DetectorParams, } impl ChatDetectionRequest { - pub fn new(messages: Vec) -> Self { - Self { messages } + pub fn new(messages: Vec, detector_params: DetectorParams) -> Self { + Self { + messages, + detector_params, + } } } diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 57c68e17..6a86f613 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -786,7 +786,7 @@ pub async fn detect_for_chat( .default_threshold, ), ); - let request = ChatDetectionRequest::new(messages.clone()); + let request = ChatDetectionRequest::new(messages.clone(), detector_params.clone()); debug!(%detector_id, ?request, "sending chat detector request"); let client = ctx .clients From 0fa1324bc6f686975bd3da77899f2a7b7d5bd425 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 15 Oct 2024 19:12:30 -0300 Subject: [PATCH 26/50] Revert "Replace openai Message by String" This reverts commit 1cfcc1ea4d7ad578cc2c925f2aaa32d86ee28261. Signed-off-by: Mateus Devino --- src/clients/detector/text_chat.rs | 6 +++--- src/models.rs | 7 +++++-- src/orchestrator.rs | 3 ++- src/orchestrator/unary.rs | 3 ++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index 2bf3954a..9d4df05a 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -4,7 +4,7 @@ use serde::Serialize; use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ - clients::{create_http_client, Client, Error, HttpClient}, + clients::{create_http_client, openai::Message, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, models::{DetectionResult, DetectorParams}, @@ -85,13 +85,13 @@ impl Client for TextChatDetectorClient { #[derive(Debug, Serialize)] pub struct ChatDetectionRequest { /// Chat messages to run detection on - pub messages: Vec, + pub messages: Vec, pub detector_params: DetectorParams, } impl ChatDetectionRequest { - pub fn new(messages: Vec, detector_params: DetectorParams) -> Self { + pub fn new(messages: Vec, detector_params: DetectorParams) -> Self { Self { messages, detector_params, diff --git a/src/models.rs b/src/models.rs index 8e5f706d..e8771051 100644 --- a/src/models.rs +++ b/src/models.rs @@ -22,7 +22,10 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use crate::{ - clients::detector::{ContentAnalysisResponse, ContextType}, + clients::{ + self, + detector::{ContentAnalysisResponse, ContextType}, + }, health::HealthCheckCache, pb, }; @@ -946,7 +949,7 @@ pub struct ChatDetectionHttpRequest { pub detectors: HashMap, // The list of messages to run detections on. - pub messages: Vec, + pub messages: Vec, } impl ChatDetectionHttpRequest { diff --git a/src/orchestrator.rs b/src/orchestrator.rs index ed0a95ab..98c9e164 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -29,6 +29,7 @@ use uuid::Uuid; use crate::{ clients::{ + self, chunker::ChunkerClient, detector::{ text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, @@ -392,7 +393,7 @@ pub struct ChatDetectionTask { pub detectors: HashMap, // Messages to run detection on - pub messages: Vec, + pub messages: Vec, // Headermap pub headers: HeaderMap, diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 6a86f613..b004b1ab 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -38,6 +38,7 @@ use crate::{ TextChatDetectorClient, TextContentsDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, }, + openai::Message, GenerationClient, }, models::{ @@ -773,7 +774,7 @@ pub async fn detect_for_chat( ctx: Arc, detector_id: String, detector_params: DetectorParams, - messages: Vec, + messages: Vec, headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); From b923b9425d1e85d1281bd3d286bf9d6aae16c5b4 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 17 Oct 2024 14:15:53 -0300 Subject: [PATCH 27/50] Add header filtering Signed-off-by: Mateus Devino --- src/server.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/server.rs b/src/server.rs index b891efef..a836295b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -442,6 +442,7 @@ async fn detect_chat( ) -> Result { let request_id = Uuid::new_v4(); request.validate()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = ChatDetectionTask::new(request_id, request, headers); match state.orchestrator.handle_chat_detection(task).await { Ok(response) => Ok(Json(response).into_response()), From acd7c1f7870c9ee24617bd575cae2fff8d10dc11 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 17 Oct 2024 14:16:20 -0300 Subject: [PATCH 28/50] Log chat detector requests and responses Signed-off-by: Mateus Devino --- src/clients/detector/text_chat.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index 9d4df05a..68c30bef 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -39,14 +39,17 @@ impl TextChatDetectorClient { headers: HeaderMap, ) -> Result, Error> { let url = self.client.base_url().join("/api/v1/text/chat").unwrap(); - let response = self + let request = self .client .post(url) .headers(headers) .header(DETECTOR_ID_HEADER_NAME, model_id) - .json(&request) - .send() - .await?; + .json(&request); + + tracing::debug!("Request being sent to chat detector: {:?}", request); + let response = request.send().await?; + tracing::debug!("Response received from chat detector: {:?}", response); + if response.status() == StatusCode::OK { Ok(response.json().await?) } else { From 5959dd3ae0b5edbc4f4c4644d3cf4fca9a4047de Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 17 Oct 2024 20:10:03 -0300 Subject: [PATCH 29/50] Add content type validation Signed-off-by: Mateus Devino --- src/models.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/models.rs b/src/models.rs index e8771051..44fae16c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -25,6 +25,7 @@ use crate::{ clients::{ self, detector::{ContentAnalysisResponse, ContextType}, + openai::Content, }, health::HealthCheckCache, pb, @@ -963,11 +964,46 @@ impl ChatDetectionHttpRequest { return Err(ValidationError::Required("messages".into())); } + // validate messages + self.validate_messages()?; + // Validate detector params validate_detector_params(&self.detectors)?; Ok(()) } + + /// Validates if message contents are either a string or a content type of type "text" + fn validate_messages(&self) -> Result<(), ValidationError> { + for message in &self.messages { + match &message.content { + Some(content) => self.validate_content_type(content)?, + None => { + return Err(ValidationError::Invalid( + "Message content cannot be empty".into(), + )) + } + } + } + Ok(()) + } + + /// Validates if content type array contains only text messages + fn validate_content_type(&self, content: &Content) -> Result<(), ValidationError> { + match content { + Content::Array(content) => { + for content_part in content { + if content_part.r#type != "text" { + return Err(ValidationError::Invalid( + "Only content of type text is allowed".into(), + )); + } + } + Ok(()) + } + Content::String(_) => Ok(()), // if message.content is a string, it is a valid message + } + } } /// The response format of the /api/v2/text/detection/chat endpoint From c7936989e0dbb2665a9fc1747a1d6efda4d6af99 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Fri, 18 Oct 2024 13:41:12 -0300 Subject: [PATCH 30/50] Update src/orchestrator/unary.rs Co-authored-by: Gaurav Kumbhat Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> --- src/orchestrator/unary.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index b004b1ab..87c89a9a 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -494,7 +494,7 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "detection on chat content task failed"); + error!(request_id = ?task.request_id, %error, "detection task on chat failed"); Err(error) } // Task cancelled or panicked From ac4008ae403e6fa3364e3e4d065c9ee885419471 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Fri, 18 Oct 2024 13:41:26 -0300 Subject: [PATCH 31/50] Update src/orchestrator/unary.rs Co-authored-by: Gaurav Kumbhat Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> --- src/orchestrator/unary.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 87c89a9a..4d5b6140 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -500,7 +500,7 @@ impl Orchestrator { // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "detection on chat content task failed"); + error!(request_id = ?task.request_id, %error, "detection task on chat failed"); Err(error) } } From 3f96193efb71f49d84b8799f6eb2fbd043eb6c64 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 18 Oct 2024 16:24:59 -0300 Subject: [PATCH 32/50] Extract chat endpoint as a constant Signed-off-by: Mateus Devino --- src/clients/detector/text_chat.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index 68c30bef..e284dabc 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -10,6 +10,8 @@ use crate::{ models::{DetectionResult, DetectorParams}, }; +const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat"; + #[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TextChatDetectorClient { @@ -38,7 +40,7 @@ impl TextChatDetectorClient { request: ChatDetectionRequest, headers: HeaderMap, ) -> Result, Error> { - let url = self.client.base_url().join("/api/v1/text/chat").unwrap(); + let url = self.client.base_url().join(CHAT_DETECTOR_ENDPOINT).unwrap(); let request = self .client .post(url) From 47a8dbdbc70b60f1d02640c80e3b2a077e1c2029 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 18 Oct 2024 16:45:23 -0300 Subject: [PATCH 33/50] Split text-specific validation Signed-off-by: Mateus Devino --- src/models.rs | 9 ++++++--- src/server.rs | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/models.rs b/src/models.rs index 44fae16c..271c82aa 100644 --- a/src/models.rs +++ b/src/models.rs @@ -964,10 +964,13 @@ impl ChatDetectionHttpRequest { return Err(ValidationError::Required("messages".into())); } - // validate messages - self.validate_messages()?; + Ok(()) + } - // Validate detector params + /// Validates for the "/api/v1/text/chat" endpoint. + pub fn validate_for_text(&self) -> Result<(), ValidationError> { + self.validate()?; + self.validate_messages()?; validate_detector_params(&self.detectors)?; Ok(()) diff --git a/src/server.rs b/src/server.rs index a836295b..ed921f9e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -441,7 +441,7 @@ async fn detect_chat( WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { let request_id = Uuid::new_v4(); - request.validate()?; + request.validate_for_text()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = ChatDetectionTask::new(request_id, request, headers); match state.orchestrator.handle_chat_detection(task).await { From 36d33cb5c46cfe7d48f2326549f11cdbd1f32905 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Tue, 22 Oct 2024 11:03:53 -0500 Subject: [PATCH 34/50] :bug: Fix threshold getting passed through beyond orchestrator processing Signed-off-by: gkumbhat --- src/models.rs | 10 ++++++---- src/orchestrator/streaming.rs | 2 ++ src/orchestrator/unary.rs | 10 +++++----- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/models.rs b/src/models.rs index 271c82aa..08e8d3ed 100644 --- a/src/models.rs +++ b/src/models.rs @@ -54,8 +54,8 @@ impl DetectorParams { } /// Threshold to filter detector results by score. - pub fn threshold(&self) -> Option { - self.0.get("threshold").and_then(|v| v.as_f64()) + pub fn threshold(&mut self) -> Option { + self.0.remove("threshold").and_then(|v| v.as_f64()) } } @@ -1272,9 +1272,11 @@ mod tests { { "threshold": 0.2 }"#; - let value: DetectorParams = serde_json::from_str(value_json)?; + let mut value: DetectorParams = serde_json::from_str(value_json)?; assert_eq!(value.threshold(), Some(0.2)); - let value = DetectorParams::new(); + assert!(!value.contains_key("threshold")); + let mut value = DetectorParams::new(); + assert!(!value.contains_key("threshold")); assert_eq!(value.threshold(), None); Ok(()) } diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index e6a735eb..aa4dbf35 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -267,6 +267,8 @@ async fn streaming_output_detection_task( debug!("spawning detection tasks"); let mut detection_streams = Vec::with_capacity(detectors.len()); for (detector_id, detector_params) in detectors.iter() { + // Create a mutable copy of the parameters, so that we can modify it based on processing + let mut detector_params = detector_params.clone(); let detector_id = detector_id.to_string(); let chunker_id = ctx.config.get_chunker_id(&detector_id).unwrap(); diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 4d5b6140..58396696 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -617,7 +617,7 @@ pub async fn detect( ctx: Arc, detector_id: String, default_threshold: f64, - detector_params: DetectorParams, + mut detector_params: DetectorParams, chunks: Vec, headers: HeaderMap, ) -> Result, Error> { @@ -676,7 +676,7 @@ pub async fn detect_content( ctx: Arc, detector_id: String, default_threshold: f64, - detector_params: DetectorParams, + mut detector_params: DetectorParams, chunks: Vec, headers: HeaderMap, ) -> Result, Error> { @@ -731,7 +731,7 @@ pub async fn detect_content( pub async fn detect_for_generation( ctx: Arc, detector_id: String, - detector_params: DetectorParams, + mut detector_params: DetectorParams, prompt: String, generated_text: String, headers: HeaderMap, @@ -773,7 +773,7 @@ pub async fn detect_for_generation( pub async fn detect_for_chat( ctx: Arc, detector_id: String, - detector_params: DetectorParams, + mut detector_params: DetectorParams, messages: Vec, headers: HeaderMap, ) -> Result, Error> { @@ -814,7 +814,7 @@ pub async fn detect_for_chat( pub async fn detect_for_context( ctx: Arc, detector_id: String, - detector_params: DetectorParams, + mut detector_params: DetectorParams, content: String, context_type: ContextType, context: Vec, From a9ae0407cd2ec3477f3758203149783ea1741923 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Wed, 23 Oct 2024 09:51:51 -0500 Subject: [PATCH 35/50] :truck: Rename threshold to pop threshold to make it intuitive Signed-off-by: gkumbhat --- src/models.rs | 6 +++--- src/orchestrator/streaming.rs | 2 +- src/orchestrator/unary.rs | 16 ++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/models.rs b/src/models.rs index 08e8d3ed..b79bb120 100644 --- a/src/models.rs +++ b/src/models.rs @@ -54,7 +54,7 @@ impl DetectorParams { } /// Threshold to filter detector results by score. - pub fn threshold(&mut self) -> Option { + pub fn pop_threshold(&mut self) -> Option { self.0.remove("threshold").and_then(|v| v.as_f64()) } } @@ -1273,11 +1273,11 @@ mod tests { "threshold": 0.2 }"#; let mut value: DetectorParams = serde_json::from_str(value_json)?; - assert_eq!(value.threshold(), Some(0.2)); + assert_eq!(value.pop_threshold(), Some(0.2)); assert!(!value.contains_key("threshold")); let mut value = DetectorParams::new(); assert!(!value.contains_key("threshold")); - assert_eq!(value.threshold(), None); + assert_eq!(value.pop_threshold(), None); Ok(()) } } diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index aa4dbf35..dd0007af 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -278,7 +278,7 @@ async fn streaming_output_detection_task( // Get the default threshold to use if threshold is not provided by the user let default_threshold = detector_config.default_threshold; - let threshold = detector_params.threshold().unwrap_or(default_threshold); + let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); // Create detection stream let (detector_tx, detector_rx) = mpsc::channel(1024); diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 58396696..1a496c21 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -622,7 +622,7 @@ pub async fn detect( headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or(default_threshold); + let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); let contents: Vec<_> = chunks.iter().map(|chunk| chunk.text.clone()).collect(); let response = if contents.is_empty() { // skip detector call as contents is empty @@ -681,7 +681,7 @@ pub async fn detect_content( headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or(default_threshold); + let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); let contents: Vec<_> = chunks.iter().map(|chunk| chunk.text.clone()).collect(); let response = if contents.is_empty() { // skip detector call as contents is empty @@ -737,8 +737,8 @@ pub async fn detect_for_generation( headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or( - detector_params.threshold().unwrap_or( + let threshold = detector_params.pop_threshold().unwrap_or( + detector_params.pop_threshold().unwrap_or( ctx.config .detectors .get(&detector_id) @@ -778,8 +778,8 @@ pub async fn detect_for_chat( headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or( - detector_params.threshold().unwrap_or( + let threshold = detector_params.pop_threshold().unwrap_or( + detector_params.pop_threshold().unwrap_or( ctx.config .detectors .get(&detector_id) @@ -821,8 +821,8 @@ pub async fn detect_for_context( headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or( - detector_params.threshold().unwrap_or( + let threshold = detector_params.pop_threshold().unwrap_or( + detector_params.pop_threshold().unwrap_or( ctx.config .detectors .get(&detector_id) From 6a58107cf268cc449d5ad0447e76166ecf86983a Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:33:06 -0600 Subject: [PATCH 36/50] :sparkles::bug: Pass on non-threshold detector parameters (#235) * :sparkles: Pass along detector_params for text contents and generation Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :sparkles::white_check_mark: Not pass on threshold param Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :bug: Mutable threshold not intuitive Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * Update src/clients/detector/text_chat.rs Co-authored-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :bulb::art: Update params comment and remove unnecessary clones Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :white_check_mark: Threshold re-call not problematic Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Co-authored-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> --- src/clients/detector/text_chat.rs | 1 + src/clients/detector/text_contents.rs | 11 +++++-- src/clients/detector/text_context_doc.rs | 2 +- src/clients/detector/text_generation.rs | 8 +++-- src/models.rs | 4 ++- src/orchestrator/streaming.rs | 5 ++- src/orchestrator/unary.rs | 42 +++++++++++++++--------- 7 files changed, 51 insertions(+), 22 deletions(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index 563d80f8..a4f0449f 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -109,6 +109,7 @@ pub struct ChatDetectionRequest { /// Chat messages to run detection on pub messages: Vec, + /// Detector parameters (available parameters depend on the detector) pub detector_params: DetectorParams, } diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index b8959c0f..1e2016fc 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -24,6 +24,7 @@ use crate::{ clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, + models::DetectorParams, }; #[cfg_attr(test, faux::create)] @@ -106,11 +107,17 @@ impl Client for TextContentsDetectorClient { pub struct ContentAnalysisRequest { /// Field allowing users to provide list of documents for analysis pub contents: Vec, + + /// Detector parameters (available parameters depend on the detector) + pub detector_params: DetectorParams, } impl ContentAnalysisRequest { - pub fn new(contents: Vec) -> ContentAnalysisRequest { - ContentAnalysisRequest { contents } + pub fn new(contents: Vec, detector_params: DetectorParams) -> ContentAnalysisRequest { + ContentAnalysisRequest { + contents, + detector_params, + } } } diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index ae9973e8..ae18b317 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -114,7 +114,7 @@ pub struct ContextDocsDetectionRequest { /// Context to run detection on pub context: Vec, - // Detector Params + /// Detector parameters (available parameters depend on the detector) pub detector_params: DetectorParams, } diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 98236dfe..2a9692d7 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -24,7 +24,7 @@ use crate::{ clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, - models::DetectionResult, + models::{DetectionResult, DetectorParams}, }; #[cfg_attr(test, faux::create)] @@ -110,13 +110,17 @@ pub struct GenerationDetectionRequest { /// Text generated from an LLM pub generated_text: String, + + /// Detector parameters (available parameters depend on the detector) + pub detector_params: DetectorParams, } impl GenerationDetectionRequest { - pub fn new(prompt: String, generated_text: String) -> Self { + pub fn new(prompt: String, generated_text: String, detector_params: DetectorParams) -> Self { Self { prompt, generated_text, + detector_params, } } } diff --git a/src/models.rs b/src/models.rs index b79bb120..dee1ffc5 100644 --- a/src/models.rs +++ b/src/models.rs @@ -31,6 +31,8 @@ use crate::{ pb, }; +pub const THRESHOLD_PARAM: &str = "threshold"; + #[derive(Clone, Debug, Serialize)] pub struct InfoResponse { pub services: HealthCheckCache, @@ -55,7 +57,7 @@ impl DetectorParams { /// Threshold to filter detector results by score. pub fn pop_threshold(&mut self) -> Option { - self.0.remove("threshold").and_then(|v| v.as_f64()) + self.0.remove(THRESHOLD_PARAM).and_then(|v| v.as_f64()) } } diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index dd0007af..aff0e42a 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -291,6 +291,7 @@ async fn streaming_output_detection_task( tokio::spawn(detection_task( ctx.clone(), detector_id.clone(), + detector_params, threshold, detector_tx, chunk_rx, @@ -354,9 +355,11 @@ async fn generation_broadcast_task( /// Consumes chunk broadcast stream, sends unary requests to a detector service, /// and sends chunk + responses to detection stream. #[instrument(skip_all)] +#[allow(clippy::too_many_arguments)] async fn detection_task( ctx: Arc, detector_id: String, + detector_params: DetectorParams, threshold: f64, detector_tx: mpsc::Sender<(Chunk, Detections)>, mut chunk_rx: broadcast::Receiver, @@ -382,7 +385,7 @@ async fn detection_task( debug!("empty chunk, skipping detector request."); break; } else { - let request = ContentAnalysisRequest::new(contents.clone()); + let request = ContentAnalysisRequest::new(contents.clone(), detector_params.clone()); let headers = headers.clone(); debug!(%detector_id, ?request, "sending detector request"); let client = ctx diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 1a496c21..d9b0a03a 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -628,7 +628,7 @@ pub async fn detect( // skip detector call as contents is empty Vec::default() } else { - let request = ContentAnalysisRequest::new(contents); + let request = ContentAnalysisRequest::new(contents, detector_params); debug!(%detector_id, ?request, "sending detector request"); let client = ctx .clients @@ -687,7 +687,7 @@ pub async fn detect_content( // skip detector call as contents is empty Vec::default() } else { - let request = ContentAnalysisRequest::new(contents); + let request = ContentAnalysisRequest::new(contents, detector_params); debug!(%detector_id, ?request, "sending detector request"); let client = ctx .clients @@ -746,7 +746,8 @@ pub async fn detect_for_generation( .default_threshold, ), ); - let request = GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()); + let request = + GenerationDetectionRequest::new(prompt.clone(), generated_text.clone(), detector_params); debug!(%detector_id, ?request, "sending generation detector request"); let client = ctx .clients @@ -787,7 +788,7 @@ pub async fn detect_for_chat( .default_threshold, ), ); - let request = ChatDetectionRequest::new(messages.clone(), detector_params.clone()); + let request = ChatDetectionRequest::new(messages.clone(), detector_params); debug!(%detector_id, ?request, "sending chat detector request"); let client = ctx .clients @@ -966,7 +967,7 @@ mod tests { ClientMap, GenerationClient, TgisClient, }, config::{DetectorConfig, OrchestratorConfig}, - models::{DetectionResult, EvidenceObj, FinishReason}, + models::{DetectionResult, EvidenceObj, FinishReason, THRESHOLD_PARAM}, pb::fmaas::{ BatchedGenerationRequest, BatchedGenerationResponse, GenerationRequest, GenerationResponse, StopReason, @@ -1050,7 +1051,7 @@ mod tests { let first_sentence = "I don't like potatoes.".to_string(); let second_sentence = "I hate aliens.".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let chunks = vec![ Chunk { offset: 0, @@ -1075,7 +1076,10 @@ mod tests { faux::when!(detector_client.text_contents( detector_id, - ContentAnalysisRequest::new(vec![first_sentence.clone(), second_sentence.clone()]), + ContentAnalysisRequest::new( + vec![first_sentence.clone(), second_sentence.clone()], + DetectorParams::new() + ), HeaderMap::new(), )) .once() @@ -1130,7 +1134,7 @@ mod tests { let sentence = "This call will return a 503.".to_string(); let threshold = 0.5; let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let chunks = vec![Chunk { offset: 0, text: sentence.clone(), @@ -1147,7 +1151,7 @@ mod tests { faux::when!(detector_client.text_contents( detector_id, - ContentAnalysisRequest::new(vec![sentence.clone()]), + ContentAnalysisRequest::new(vec![sentence.clone()], DetectorParams::new()), HeaderMap::new(), )) .once() @@ -1185,7 +1189,7 @@ mod tests { let threshold = 0.5; let first_sentence = "".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let chunks = vec![Chunk { offset: 0, text: first_sentence.clone(), @@ -1193,7 +1197,7 @@ mod tests { faux::when!(detector_client.text_contents( detector_id, - ContentAnalysisRequest::new(vec![first_sentence.clone()]), + ContentAnalysisRequest::new(vec![first_sentence.clone()], DetectorParams::new()), HeaderMap::new(), )) .once() @@ -1230,7 +1234,7 @@ mod tests { let prompt = "What is the capital of Brazil?".to_string(); let generated_text = "The capital of Brazil is Brasilia.".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let expected_response: Vec = vec![DetectionResult { detection_type: "relevance".to_string(), @@ -1249,7 +1253,11 @@ mod tests { faux::when!(detector_client.text_generation( detector_id, - GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()), + GenerationDetectionRequest::new( + prompt.clone(), + generated_text.clone(), + DetectorParams::new() + ), HeaderMap::new(), )) .once() @@ -1307,13 +1315,17 @@ mod tests { let generated_text = "The most beautiful places can be found in Rio de Janeiro.".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let expected_response: Vec = vec![]; faux::when!(detector_client.text_generation( detector_id, - GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()), + GenerationDetectionRequest::new( + prompt.clone(), + generated_text.clone(), + DetectorParams::new() + ), HeaderMap::new(), )) .once() From 4aa4c8781a36ae38cc88ec24ac6547097cc50479 Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Tue, 15 Oct 2024 19:02:33 -0400 Subject: [PATCH 37/50] client instrumentation and replacing request_id with trace_id Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients.rs | 12 +++ src/clients/chunker.rs | 6 +- src/clients/detector/text_chat.rs | 8 +- src/clients/detector/text_contents.rs | 3 + src/clients/detector/text_context_doc.rs | 3 + src/clients/detector/text_generation.rs | 3 + src/clients/generation.rs | 31 +++++--- src/clients/nlp.rs | 20 +++-- src/clients/tgis.rs | 8 ++ src/orchestrator.rs | 46 +++++------ src/orchestrator/streaming.rs | 54 +++++++------ src/orchestrator/unary.rs | 97 +++++++++++++++--------- src/server.rs | 35 +++++---- src/tracing_utils.rs | 45 +++++++---- 14 files changed, 238 insertions(+), 133 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 392eac05..992da6a8 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -27,6 +27,7 @@ use async_trait::async_trait; use futures::Stream; use ginepro::LoadBalancedChannel; use tokio::{fs::File, io::AsyncReadExt}; +use tracing::{debug, instrument}; use url::Url; use crate::{ @@ -193,6 +194,7 @@ impl ClientMap { } } +#[instrument(skip_all, fields(hostname = service_config.hostname))] pub async fn create_http_client(default_port: u16, service_config: &ServiceConfig) -> HttpClient { let port = service_config.port.unwrap_or(default_port); let protocol = match service_config.tls { @@ -201,6 +203,7 @@ pub async fn create_http_client(default_port: u16, service_config: &ServiceConfi }; let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap(); base_url.set_port(Some(port)).unwrap(); + debug!(%base_url, "creating HTTP client"); let request_timeout = Duration::from_secs( service_config .request_timeout @@ -250,11 +253,20 @@ pub async fn create_http_client(default_port: u16, service_config: &ServiceConfi HttpClient::new(base_url, client) } +#[instrument(skip_all, fields(hostname = service_config.hostname))] pub async fn create_grpc_client( default_port: u16, service_config: &ServiceConfig, new: fn(LoadBalancedChannel) -> C, ) -> C { + let port = service_config.port.unwrap_or(default_port); + let protocol = match service_config.tls { + Some(_) => "https", + None => "http", + }; + let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap(); + base_url.set_port(Some(port)).unwrap(); + debug!(%base_url, "creating gRPC client"); let request_timeout = Duration::from_secs( service_config .request_timeout diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 65795653..1e877015 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -21,6 +21,7 @@ use async_trait::async_trait; use futures::{Future, Stream, StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::{Code, Request, Response, Status, Streaming}; +use tracing::{info, instrument}; use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ @@ -62,6 +63,7 @@ impl ChunkerClient { } } + #[instrument(skip_all, fields(model_id))] pub async fn tokenization_task_predict( &self, model_id: &str, @@ -69,17 +71,20 @@ impl ChunkerClient { ) -> Result { let mut client = self.client.clone(); let request = request_with_model_id(request, model_id); + info!(?request, "sending client request"); Ok(client .chunker_tokenization_task_predict(request) .await? .into_inner()) } + #[instrument(skip_all, fields(model_id))] pub async fn bidi_streaming_tokenization_task_predict( &self, model_id: &str, request_stream: BoxStream, ) -> Result>, Error> { + info!("sending client stream request"); let mut client = self.client.clone(); let request = request_with_model_id(request_stream, model_id); // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors. @@ -91,7 +96,6 @@ impl ChunkerClient { .into_inner() .map_err(Into::into) .boxed(); - Ok(response_stream) } } diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index a4f0449f..aef3a0c9 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -18,6 +18,7 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::Serialize; +use tracing::{debug, info, instrument}; use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ @@ -51,6 +52,7 @@ impl TextChatDetectorClient { } } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn text_chat( &self, model_id: &str, @@ -64,10 +66,10 @@ impl TextChatDetectorClient { .headers(headers) .header(DETECTOR_ID_HEADER_NAME, model_id) .json(&request); - - tracing::debug!("Request being sent to chat detector: {:?}", request); + info!(?url, "sending chat detector client request"); + debug!("chat detector client request: {:?}", request); let response = request.send().await?; - tracing::debug!("Response received from chat detector: {:?}", response); + debug!("chat detector client response: {:?}", response); if response.status() == StatusCode::OK { Ok(response.json().await?) diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index 1e2016fc..d7df440e 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -18,6 +18,7 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; +use tracing::{info, instrument}; use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ @@ -49,6 +50,7 @@ impl TextContentsDetectorClient { } } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn text_contents( &self, model_id: &str, @@ -60,6 +62,7 @@ impl TextContentsDetectorClient { .base_url() .join("/api/v1/text/contents") .unwrap(); + info!(?url, ?request, "sending client request"); let response = self .client .post(url) diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index ae18b317..e73a985f 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -18,6 +18,7 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; +use tracing::{info, instrument}; use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ @@ -49,6 +50,7 @@ impl TextContextDocDetectorClient { } } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn text_context_doc( &self, model_id: &str, @@ -60,6 +62,7 @@ impl TextContextDocDetectorClient { .base_url() .join("/api/v1/text/context/doc") .unwrap(); + info!(?url, ?request, "sending client request"); let response = self .client .post(url) diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 2a9692d7..2486cea0 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -18,6 +18,7 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::Serialize; +use tracing::{info, instrument}; use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; use crate::{ @@ -49,6 +50,7 @@ impl TextGenerationDetectorClient { } } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn text_generation( &self, model_id: &str, @@ -60,6 +62,7 @@ impl TextGenerationDetectorClient { .base_url() .join("/api/v1/text/generation") .unwrap(); + info!(?url, ?request, "sending client request"); let response = self .client .post(url) diff --git a/src/clients/generation.rs b/src/clients/generation.rs index c10520dd..1b514dac 100644 --- a/src/clients/generation.rs +++ b/src/clients/generation.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; use hyper::HeaderMap; -use tracing::debug; +use tracing::{debug, instrument}; use super::{BoxStream, Client, Error, NlpClient, TgisClient}; use crate::{ @@ -63,6 +63,7 @@ impl GenerationClient { Self(None) } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn tokenize( &self, model_id: String, @@ -78,19 +79,19 @@ impl GenerationClient { return_offsets: false, truncate_input_tokens: 0, }; - debug!(%model_id, provider = "tgis", ?request, "sending tokenize request"); + debug!(provider = "tgis", ?request, "sending tokenize request"); let mut response = client.tokenize(request, headers).await?; - debug!(%model_id, provider = "tgis", ?response, "received tokenize response"); + debug!(provider = "tgis", ?response, "received tokenize response"); let response = response.responses.swap_remove(0); Ok((response.token_count, response.tokens)) } Some(GenerationClientInner::Nlp(client)) => { let request = TokenizationTaskRequest { text }; - debug!(%model_id, provider = "nlp", ?request, "sending tokenize request"); + debug!(provider = "nlp", ?request, "sending tokenize request"); let response = client .tokenization_task_predict(&model_id, request, headers) .await?; - debug!(%model_id, provider = "nlp", ?response, "received tokenize response"); + debug!(provider = "nlp", ?response, "received tokenize response"); let tokens = response .results .into_iter() @@ -118,9 +119,9 @@ impl GenerationClient { requests: vec![GenerationRequest { text }], params, }; - debug!(%model_id, provider = "tgis", ?request, "sending generate request"); + debug!(provider = "tgis", ?request, "sending generate request"); let response = client.generate(request, headers).await?; - debug!(%model_id, provider = "tgis", ?response, "received generate response"); + debug!(provider = "tgis", ?response, "received generate response"); Ok(response.into()) } Some(GenerationClientInner::Nlp(client)) => { @@ -155,11 +156,11 @@ impl GenerationClient { ..Default::default() } }; - debug!(%model_id, provider = "nlp", ?request, "sending generate request"); + debug!(provider = "nlp", ?request, "sending generate request"); let response = client .text_generation_task_predict(&model_id, request, headers) .await?; - debug!(%model_id, provider = "nlp", ?response, "received generate response"); + debug!(provider = "nlp", ?response, "received generate response"); Ok(response.into()) } None => Err(Error::ModelNotFound { model_id }), @@ -182,7 +183,11 @@ impl GenerationClient { request: Some(GenerationRequest { text }), params, }; - debug!(%model_id, provider = "tgis", ?request, "sending generate_stream request"); + debug!( + provider = "tgis", + ?request, + "sending generate_stream request" + ); let response_stream = client .generate_stream(request, headers) .await? @@ -222,7 +227,11 @@ impl GenerationClient { ..Default::default() } }; - debug!(%model_id, provider = "nlp", ?request, "sending generate_stream request"); + debug!( + provider = "nlp", + ?request, + "sending generate_stream request" + ); let response_stream = client .server_streaming_text_generation_task_predict(&model_id, request, headers) .await? diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index f6c873a6..bbceb29d 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -15,12 +15,6 @@ */ -use async_trait::async_trait; -use axum::http::{Extensions, HeaderMap}; -use futures::{StreamExt, TryStreamExt}; -use ginepro::LoadBalancedChannel; -use tonic::{metadata::MetadataMap, Code, Request}; - use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ config::ServiceConfig, @@ -37,6 +31,12 @@ use crate::{ grpc::health::v1::{health_client::HealthClient, HealthCheckRequest}, }, }; +use async_trait::async_trait; +use axum::http::{Extensions, HeaderMap}; +use futures::{StreamExt, TryStreamExt}; +use ginepro::LoadBalancedChannel; +use tonic::{metadata::MetadataMap, Code, Request}; +use tracing::{info, instrument}; const DEFAULT_PORT: u16 = 8085; const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; @@ -59,6 +59,7 @@ impl NlpClient { } } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn tokenization_task_predict( &self, model_id: &str, @@ -67,12 +68,14 @@ impl NlpClient { ) -> Result { let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); + info!(?request, "sending request to NLP gRPC service"); Ok(client .tokenization_task_predict(request) .await? .into_inner()) } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn token_classification_task_predict( &self, model_id: &str, @@ -81,12 +84,14 @@ impl NlpClient { ) -> Result { let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); + info!(?request, "sending request to NLP gRPC service"); Ok(client .token_classification_task_predict(request) .await? .into_inner()) } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn text_generation_task_predict( &self, model_id: &str, @@ -95,12 +100,14 @@ impl NlpClient { ) -> Result { let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); + info!(?request, "sending request to NLP gRPC service"); Ok(client .text_generation_task_predict(request) .await? .into_inner()) } + #[instrument(skip_all, fields(model_id, ?headers))] pub async fn server_streaming_text_generation_task_predict( &self, model_id: &str, @@ -109,6 +116,7 @@ impl NlpClient { ) -> Result>, Error> { let mut client = self.client.clone(); let request = request_with_model_id(request, model_id, headers); + info!(?request, "sending stream request to NLP gRPC service"); let response_stream = client .server_streaming_text_generation_task_predict(request) .await? diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index 09d01395..cf07493c 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -20,6 +20,7 @@ use axum::http::HeaderMap; use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::Code; +use tracing::{info, instrument}; use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; use crate::{ @@ -47,20 +48,24 @@ impl TgisClient { Self { client } } + #[instrument(skip_all, fields(model_id = request.model_id, headers = ?_headers))] pub async fn generate( &self, request: BatchedGenerationRequest, _headers: HeaderMap, ) -> Result { + info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); Ok(client.generate(request).await?.into_inner()) } + #[instrument(skip_all, fields(model_id = request.model_id, headers = ?_headers))] pub async fn generate_stream( &self, request: SingleGenerationRequest, _headers: HeaderMap, ) -> Result>, Error> { + info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); let response_stream = client .generate_stream(request) @@ -71,16 +76,19 @@ impl TgisClient { Ok(response_stream) } + #[instrument(skip_all, fields(model_id = request.model_id, headers = ?_headers))] pub async fn tokenize( &self, request: BatchedTokenizeRequest, _headers: HeaderMap, ) -> Result { + info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); Ok(client.tokenize(request).await?.into_inner()) } pub async fn model_info(&self, request: ModelInfoRequest) -> Result { + info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); Ok(client.model_info(request).await?.into_inner()) } diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 98c9e164..dda82b7d 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -23,9 +23,9 @@ pub mod unary; use std::{collections::HashMap, sync::Arc}; use axum::http::header::HeaderMap; +use opentelemetry::trace::TraceId; use tokio::{sync::RwLock, time::Instant}; use tracing::{debug, info}; -use uuid::Uuid; use crate::{ clients::{ @@ -257,7 +257,7 @@ pub struct Chunk { #[derive(Debug)] pub struct ClassificationWithGenTask { - pub request_id: Uuid, + pub trace_id: TraceId, pub model_id: String, pub inputs: String, pub guardrails_config: GuardrailsConfig, @@ -266,9 +266,9 @@ pub struct ClassificationWithGenTask { } impl ClassificationWithGenTask { - pub fn new(request_id: Uuid, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { + pub fn new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { Self { - request_id, + trace_id, model_id: request.model_id, inputs: request.inputs, guardrails_config: request.guardrail_config.unwrap_or_default(), @@ -281,8 +281,8 @@ impl ClassificationWithGenTask { /// Task for the /api/v2/text/detection/content endpoint #[derive(Debug)] pub struct GenerationWithDetectionTask { - /// Request unique identifier - pub request_id: Uuid, + /// Unique identifier of request trace + pub trace_id: TraceId, /// Model ID of the LLM pub model_id: String, @@ -302,12 +302,12 @@ pub struct GenerationWithDetectionTask { impl GenerationWithDetectionTask { pub fn new( - request_id: Uuid, + trace_id: TraceId, request: GenerationWithDetectionHttpRequest, headers: HeaderMap, ) -> Self { Self { - request_id, + trace_id, model_id: request.model_id, prompt: request.prompt, detectors: request.detectors, @@ -320,8 +320,8 @@ impl GenerationWithDetectionTask { /// Task for the /api/v2/text/detection/content endpoint #[derive(Debug)] pub struct TextContentDetectionTask { - /// Request unique identifier - pub request_id: Uuid, + /// Unique identifier of request trace + pub trace_id: TraceId, /// Content to run detection on pub content: String, @@ -335,12 +335,12 @@ pub struct TextContentDetectionTask { impl TextContentDetectionTask { pub fn new( - request_id: Uuid, + trace_id: TraceId, request: TextContentDetectionHttpRequest, headers: HeaderMap, ) -> Self { Self { - request_id, + trace_id, content: request.content, detectors: request.detectors, headers, @@ -351,8 +351,8 @@ impl TextContentDetectionTask { /// Task for the /api/v1/text/task/detection/context endpoint #[derive(Debug)] pub struct ContextDocsDetectionTask { - /// Request unique identifier - pub request_id: Uuid, + /// Unique identifier of request trace + pub trace_id: TraceId, /// Content to run detection on pub content: String, @@ -371,9 +371,9 @@ pub struct ContextDocsDetectionTask { } impl ContextDocsDetectionTask { - pub fn new(request_id: Uuid, request: ContextDocsHttpRequest, headers: HeaderMap) -> Self { + pub fn new(trace_id: TraceId, request: ContextDocsHttpRequest, headers: HeaderMap) -> Self { Self { - request_id, + trace_id, content: request.content, context_type: request.context_type, context: request.context, @@ -413,8 +413,8 @@ impl ChatDetectionTask { /// Task for the /api/v2/text/detection/generated endpoint #[derive(Debug)] pub struct DetectionOnGenerationTask { - /// Request unique identifier - pub request_id: Uuid, + /// Unique identifier of request trace + pub trace_id: TraceId, /// User prompt to be sent to the LLM pub prompt: String, @@ -431,12 +431,12 @@ pub struct DetectionOnGenerationTask { impl DetectionOnGenerationTask { pub fn new( - request_id: Uuid, + trace_id: TraceId, request: DetectionOnGeneratedHttpRequest, headers: HeaderMap, ) -> Self { Self { - request_id, + trace_id, prompt: request.prompt, generated_text: request.generated_text, detectors: request.detectors, @@ -448,7 +448,7 @@ impl DetectionOnGenerationTask { #[allow(dead_code)] #[derive(Debug)] pub struct StreamingClassificationWithGenTask { - pub request_id: Uuid, + pub trace_id: TraceId, pub model_id: String, pub inputs: String, pub guardrails_config: GuardrailsConfig, @@ -457,9 +457,9 @@ pub struct StreamingClassificationWithGenTask { } impl StreamingClassificationWithGenTask { - pub fn new(request_id: Uuid, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { + pub fn new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { Self { - request_id, + trace_id, model_id: request.model_id, inputs: request.inputs, guardrails_config: request.guardrail_config.unwrap_or_default(), diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index aff0e42a..be71648e 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -50,20 +50,20 @@ pub type Detections = Vec; impl Orchestrator { /// Handles streaming tasks. - #[instrument(name = "stream_handler", skip_all)] + #[instrument(skip_all, fields(trace_id = task.trace_id.to_string(), model_id = task.model_id, headers = ?task.headers))] pub async fn handle_streaming_classification_with_gen( &self, task: StreamingClassificationWithGenTask, ) -> ReceiverStream> { + info!(config = ?task.guardrails_config, "starting task"); + let ctx = self.ctx.clone(); - let request_id = task.request_id; + let trace_id = task.trace_id; let model_id = task.model_id; let params = task.text_gen_parameters; let input_text = task.inputs; let headers = task.headers; - info!(%request_id, config = ?task.guardrails_config, "starting task"); - // Create response channel #[allow(clippy::type_complexity)] let (response_tx, response_rx): ( @@ -88,7 +88,7 @@ impl Orchestrator { { Ok(result) => result, Err(error) => { - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); let _ = response_tx.send(Err(error)).await; return; } @@ -96,7 +96,7 @@ impl Orchestrator { } _ => None, }; - debug!(?input_detections); + debug!(?input_detections); // TODO: metrics if let Some(mut input_detections) = input_detections { // Detected HAP/PII // Do tokenization to get input_token_count @@ -106,7 +106,7 @@ impl Orchestrator { { Ok(result) => result, Err(error) => { - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); let _ = response_tx.send(Err(error)).await; return; } @@ -141,7 +141,7 @@ impl Orchestrator { { Ok(generation_stream) => generation_stream, Err(error) => { - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); let _ = response_tx.send(Err(error)).await; return; } @@ -171,7 +171,7 @@ impl Orchestrator { { Ok(result_rx) => result_rx, Err(error) => { - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); let _ = error_tx.send(error.clone()); let _ = response_tx.send(Err(error)).await; return; @@ -183,19 +183,19 @@ impl Orchestrator { loop { tokio::select! { Ok(error) = error_rx.recv() => { - error!(%request_id, %error, "task failed"); - debug!(%request_id, "sending error to client and terminating"); + error!(%trace_id, %error, "task failed"); + debug!(%trace_id, "sending error to client and terminating"); let _ = response_tx.send(Err(error)).await; return; }, result = result_rx.recv() => { match result { Some(result) => { - debug!(%request_id, ?result, "sending result to client"); + debug!(%trace_id, ?result, "sending result to client"); let _ = response_tx.send(result).await; }, None => { - info!(%request_id, "task completed: stream closed"); + info!(%trace_id, "task completed: stream closed"); break; }, } @@ -208,10 +208,10 @@ impl Orchestrator { // No output detectors, forward generation results to response channel tokio::spawn(async move { while let Some(result) = generation_stream.next().await { - debug!(%request_id, ?result, "sending result to client"); + debug!(%trace_id, ?result, "sending result to client"); let _ = response_tx.send(result).await; } - debug!(%request_id, "task completed: stream closed"); + debug!(%trace_id, "task completed: stream closed"); }); } } @@ -232,10 +232,11 @@ async fn streaming_output_detection_task( error_tx: broadcast::Sender, headers: HeaderMap, ) -> Result>, Error> { + debug!(?detectors, "creating chunk broadcast streams"); + // Create generation broadcast stream let (generation_tx, generation_rx) = broadcast::channel(1024); - debug!("creating chunk broadcast streams"); let chunker_ids = get_chunker_ids(ctx, detectors)?; // Create a map of chunker_id->chunk_broadcast_stream // This is to enable fan-out of chunk streams to potentially multiple detectors that use the same chunker. @@ -325,6 +326,7 @@ async fn generation_broadcast_task( generation_tx: broadcast::Sender, error_tx: broadcast::Sender, ) { + debug!("forwarding response stream"); let mut error_rx = error_tx.subscribe(); loop { tokio::select! { @@ -354,8 +356,8 @@ async fn generation_broadcast_task( /// Wraps a unary detector service to make it streaming. /// Consumes chunk broadcast stream, sends unary requests to a detector service, /// and sends chunk + responses to detection stream. -#[instrument(skip_all)] #[allow(clippy::too_many_arguments)] +#[instrument(skip_all, fields(detector_id))] async fn detection_task( ctx: Arc, detector_id: String, @@ -366,6 +368,7 @@ async fn detection_task( error_tx: broadcast::Sender, headers: HeaderMap, ) { + debug!(threshold, "starting task"); let mut error_rx = error_tx.subscribe(); loop { @@ -433,7 +436,7 @@ async fn detection_task( /// Opens bi-directional stream to a chunker service /// with generation stream input and returns chunk broadcast stream. -#[instrument(skip_all)] +#[instrument(skip_all, fields(chunker_id))] async fn chunk_broadcast_task( ctx: Arc, chunker_id: String, @@ -441,7 +444,7 @@ async fn chunk_broadcast_task( error_tx: broadcast::Sender, ) -> Result, Error> { // Consume generation stream and convert to chunker input stream - debug!(%chunker_id, "creating chunker input stream"); + debug!("creating chunker input stream"); // NOTE: Text gen providers can return more than 1 token in single stream object. This can create // edge cases where the enumeration generated below may not line up with token / response boundaries. // So the more accurate way here might be to use `Tokens` object from response, but since that is an @@ -459,7 +462,7 @@ async fn chunk_broadcast_task( } }) .boxed(); - debug!(%chunker_id, "creating chunker output stream"); + debug!("creating chunker output stream"); let id = chunker_id.clone(); // workaround for StreamExt::map_err let response_stream = if chunker_id == DEFAULT_CHUNKER_ID { @@ -490,7 +493,7 @@ async fn chunk_broadcast_task( }); // maps stream errors // Spawn task to consume output stream forward to broadcast channel - debug!(%chunker_id, "spawning chunker broadcast task"); + debug!("spawning chunker broadcast task"); let (chunk_tx, _) = broadcast::channel(1024); tokio::spawn({ let mut error_rx = error_tx.subscribe(); @@ -502,17 +505,17 @@ async fn chunk_broadcast_task( result = output_stream.next() => { match result { Some(Ok(chunk)) => { - debug!(%chunker_id, ?chunk, "received chunk"); + debug!(?chunk, "received chunk"); let _ = chunk_tx.send(chunk); }, Some(Err(error)) => { - error!(%chunker_id, %error, "chunker error, cancelling task"); + error!(%error, "chunker error, cancelling task"); let _ = error_tx.send(error); tokio::time::sleep(Duration::from_millis(5)).await; break; }, None => { - debug!(%chunker_id, "stream closed"); + debug!("stream closed"); break }, } @@ -525,6 +528,8 @@ async fn chunk_broadcast_task( } /// Sends generate stream request to a generation service. +#[allow(clippy::type_complexity)] +#[instrument(skip_all, fields(model_id))] async fn generate_stream( ctx: &Arc, model_id: String, @@ -535,6 +540,7 @@ async fn generate_stream( Pin> + Send>>, Error, > { + debug!(?params, "sending generate stream request"); let client = ctx .clients .get_as::("generation") diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index d9b0a03a..9a2d1a4d 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -56,15 +56,15 @@ const DEFAULT_STREAM_BUFFER_SIZE: usize = 5; impl Orchestrator { /// Handles unary tasks. - #[instrument(name = "unary_handler", skip_all)] + #[instrument(skip_all, fields(trace_id = ?task.trace_id, model_id = task.model_id, headers = ?task.headers))] pub async fn handle_classification_with_gen( &self, task: ClassificationWithGenTask, ) -> Result { let ctx = self.ctx.clone(); - let request_id = task.request_id; + let trace_id = task.trace_id; let headers = task.headers; - info!(%request_id, config = ?task.guardrails_config, "starting task"); + info!(config = ?task.guardrails_config, "handling classification with generation task"); let task_handle = tokio::spawn(async move { let input_text = task.inputs.clone(); let masks = task.guardrails_config.input_masks(); @@ -144,32 +144,31 @@ impl Orchestrator { match task_handle.await { // Task completed successfully Ok(Ok(result)) => { - debug!(%request_id, ?result, "sending result to client"); - info!(%request_id, "task completed"); + debug!(%trace_id, ?result, "sending result to client"); + info!(%trace_id, "task completed"); Ok(result) } // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); Err(error) } } } /// Handles the given generation task, followed by detections. + #[instrument(skip_all, fields(trace_id = ?task.trace_id, model_id = task.model_id, headers = ?task.headers))] pub async fn handle_generation_with_detection( &self, task: GenerationWithDetectionTask, ) -> Result { info!( - request_id = ?task.request_id, - model_id = %task.model_id, detectors = ?task.detectors, "handling generation with detection task" ); @@ -229,27 +228,25 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "generation with detection unary task failed"); + error!(trace_id = ?task.trace_id, %error, "generation with detection unary task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "generation with detection unary task failed"); + error!(trace_id = ?task.trace_id, %error, "generation with detection unary task failed"); Err(error) } } } /// Handles detection on textual content + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] pub async fn handle_text_content_detection( &self, task: TextContentDetectionTask, ) -> Result { - info!( - request_id = ?task.request_id, - "handling text content detection task" - ); + info!("handling text content detection task"); let ctx = self.ctx.clone(); let headers = task.headers; @@ -311,25 +308,25 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "text content detection task failed"); + error!(trace_id = ?task.trace_id, %error, "text content detection task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "text content detection task failed"); + error!(trace_id = ?task.trace_id, %error, "text content detection task failed"); Err(error) } } } /// Handles context-related detections on textual content + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] pub async fn handle_context_documents_detection( &self, task: ContextDocsDetectionTask, ) -> Result { info!( - request_id = ?task.request_id, detectors = ?task.detectors, "handling context documents detection task" ); @@ -376,25 +373,25 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "context documents detection task failed"); + error!(trace_id = ?task.trace_id, %error, "context documents detection task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "context documents detection task failed"); + error!(trace_id = ?task.trace_id, %error, "context documents detection task failed"); Err(error) } } } /// Handles detections on generated text (without performing generation) + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] pub async fn handle_generated_text_detection( &self, task: DetectionOnGenerationTask, ) -> Result { info!( - request_id = ?task.request_id, detectors = ?task.detectors, "handling detection on generated content task" ); @@ -439,13 +436,13 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "detection on generated content task failed"); + error!(trace_id = ?task.trace_id, %error, "detection on generated content task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "detection on generated content task failed"); + error!(trace_id = ?task.trace_id, %error, "detection on generated content task failed"); Err(error) } } @@ -516,6 +513,7 @@ pub async fn input_detection_task( masks: Option<&[(usize, usize)]>, headers: HeaderMap, ) -> Result>, Error> { + debug!(?detectors, "starting input detection"); let text_with_offsets = apply_masks(input_text, masks); let chunker_ids = get_chunker_ids(ctx, detectors)?; let chunks = chunk_task(ctx, chunker_ids, text_with_offsets).await?; @@ -531,6 +529,7 @@ async fn output_detection_task( generated_text: String, headers: HeaderMap, ) -> Result>, Error> { + debug!(detectors = ?detectors.keys(), "starting output detection"); let text_with_offsets = apply_masks(generated_text, None); let chunker_ids = get_chunker_ids(ctx, detectors)?; let chunks = chunk_task(ctx, chunker_ids, text_with_offsets).await?; @@ -546,6 +545,7 @@ async fn detection_task( chunks: HashMap>, headers: HeaderMap, ) -> Result, Error> { + debug!(detectors = ?detectors.keys(), "handling detection task"); // Spawn tasks for each detector let tasks = detectors .iter() @@ -595,6 +595,7 @@ async fn chunk_task( chunker_ids: Vec, text_with_offsets: Vec<(usize, String)>, ) -> Result>, Error> { + debug!(?chunker_ids, "handling chunk task"); // Spawn tasks for each chunker let tasks = chunker_ids .into_iter() @@ -612,7 +613,7 @@ async fn chunk_task( } /// Sends a request to a detector service and applies threshold. -#[instrument(skip_all)] +#[instrument(skip_all, fields(detector_id))] pub async fn detect( ctx: Arc, detector_id: String, @@ -629,7 +630,7 @@ pub async fn detect( Vec::default() } else { let request = ContentAnalysisRequest::new(contents, detector_params); - debug!(%detector_id, ?request, "sending detector request"); + debug!(?request, "sending detector request"); let client = ctx .clients .get_as::(&detector_id) @@ -638,14 +639,14 @@ pub async fn detect( .text_contents(&detector_id, request, headers) .await .map_err(|error| { - debug!(%detector_id, ?error, "error received from detector"); + debug!(?error, "error received from detector"); Error::DetectorRequestFailed { id: detector_id.clone(), error, } })? }; - debug!(%detector_id, ?response, "received detector response"); + debug!(?response, "received detector response"); if chunks.len() != response.len() { return Err(Error::Other(format!( "Detector {detector_id} did not return expected number of responses" @@ -671,7 +672,7 @@ pub async fn detect( /// Sends a request to a detector service and applies threshold. /// TODO: Cleanup by removing duplicate code and merging it with above `detect` function -#[instrument(skip_all)] +#[instrument(skip_all, fields(detector_id))] pub async fn detect_content( ctx: Arc, detector_id: String, @@ -688,7 +689,11 @@ pub async fn detect_content( Vec::default() } else { let request = ContentAnalysisRequest::new(contents, detector_params); - debug!(%detector_id, ?request, "sending detector request"); + debug!( + ?request, + threshold, + "sending detector request" + ); let client = ctx .clients .get_as::(&detector_id) @@ -697,7 +702,7 @@ pub async fn detect_content( .text_contents(&detector_id, request, headers) .await .map_err(|error| { - debug!(%detector_id, ?error, "error received from detector"); + debug!(?error, "error received from detector"); Error::DetectorRequestFailed { id: detector_id.clone(), error, @@ -728,6 +733,7 @@ pub async fn detect_content( } /// Calls a detector that implements the /api/v1/text/generation endpoint +#[instrument(skip_all, fields(detector_id))] pub async fn detect_for_generation( ctx: Arc, detector_id: String, @@ -748,7 +754,11 @@ pub async fn detect_for_generation( ); let request = GenerationDetectionRequest::new(prompt.clone(), generated_text.clone(), detector_params); - debug!(%detector_id, ?request, "sending generation detector request"); + debug!( + threshold, + ?request, + "sending generation detector request" + ); let client = ctx .clients .get_as::(&detector_id) @@ -766,7 +776,7 @@ pub async fn detect_for_generation( id: detector_id.clone(), error, })?; - debug!(%detector_id, ?response, "received generation detector response"); + debug!(?response, "received generation detector response"); Ok::, Error>(response) } @@ -812,6 +822,7 @@ pub async fn detect_for_chat( } /// Calls a detector that implements the /api/v1/text/doc endpoint +#[instrument(skip_all, fields(detector_id))] pub async fn detect_for_context( ctx: Arc, detector_id: String, @@ -831,8 +842,14 @@ pub async fn detect_for_context( .default_threshold, ), ); - let request = ContextDocsDetectionRequest::new(content, context_type, context, detector_params); - debug!(%detector_id, ?request, "sending context detector request"); + let request = + ContextDocsDetectionRequest::new(content, context_type, context, detector_params.clone()); + debug!( + ?request, + threshold, + ?detector_params, + "sending context detector request" + ); let client = ctx .clients .get_as::(&detector_id) @@ -855,7 +872,7 @@ pub async fn detect_for_context( } /// Sends request to chunker service. -#[instrument(skip_all)] +#[instrument(skip_all, fields(chunker_id))] pub async fn chunk( ctx: &Arc, chunker_id: String, @@ -863,7 +880,7 @@ pub async fn chunk( text: String, ) -> Result, Error> { let request = chunkers::ChunkerTokenizationTaskRequest { text }; - debug!(%chunker_id, ?request, "sending chunker request"); + debug!(?request, offset, "sending chunk request"); let response = if chunker_id == DEFAULT_CHUNKER_ID { tokenize_whole_doc(request) } else { @@ -877,7 +894,7 @@ pub async fn chunk( })? }; - debug!(%chunker_id, ?response, "received chunker response"); + debug!(?response, "received chunker response"); Ok(response .results .into_iter() @@ -889,11 +906,13 @@ pub async fn chunk( } /// Sends parallel requests to a chunker service. +#[instrument(skip_all, fields(chunker_id))] pub async fn chunk_parallel( ctx: &Arc, chunker_id: String, text_with_offsets: Vec<(usize, String)>, ) -> Result<(String, Vec), Error> { + debug!("sending parallel chunk request"); let chunks = stream::iter(text_with_offsets) .map(|(offset, text)| { let ctx = ctx.clone(); @@ -915,12 +934,14 @@ pub async fn chunk_parallel( } /// Sends tokenize request to a generation service. +#[instrument(skip_all, fields(model_id))] pub async fn tokenize( ctx: &Arc, model_id: String, text: String, headers: HeaderMap, ) -> Result<(u32, Vec), Error> { + debug!("sending tokenize request"); let client = ctx .clients .get_as::("generation") @@ -935,6 +956,7 @@ pub async fn tokenize( } /// Sends generate request to a generation service. +#[instrument(skip_all, fields(model_id))] async fn generate( ctx: &Arc, model_id: String, @@ -942,6 +964,7 @@ async fn generate( params: Option, headers: HeaderMap, ) -> Result { + debug!("sending generate request"); let client = ctx .clients .get_as::("generation") diff --git a/src/server.rs b/src/server.rs index ed921f9e..64210149 100644 --- a/src/server.rs +++ b/src/server.rs @@ -40,13 +40,14 @@ use axum_extra::extract::WithRejection; use futures::{stream, Stream, StreamExt}; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; +use opentelemetry::trace::TraceContextExt; use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}; use tokio::{net::TcpListener, signal}; use tokio_rustls::TlsAcceptor; use tower_http::trace::TraceLayer; use tower_service::Service; -use tracing::{debug, error, info, instrument, warn}; -use uuid::Uuid; +use tracing::{debug, error, info, instrument, warn, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use webpki::types::{CertificateDer, PrivateKeyDer}; use crate::{ @@ -318,10 +319,11 @@ async fn classification_with_gen( headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ClassificationWithGenTask::new(request_id, request, headers); + let task = ClassificationWithGenTask::new(trace_id, request, headers); match state .orchestrator .handle_classification_with_gen(task) @@ -341,10 +343,11 @@ async fn generation_with_detection( Error, >, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = GenerationWithDetectionTask::new(request_id, request, headers); + let task = GenerationWithDetectionTask::new(trace_id, request, headers); match state .orchestrator .handle_generation_with_detection(task) @@ -361,7 +364,8 @@ async fn stream_classification_with_gen( headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Sse>> { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); if let Err(error) = request.validate() { // Request validation failed, return stream with single error SSE event let error: Error = error.into(); @@ -374,7 +378,7 @@ async fn stream_classification_with_gen( ); } let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = StreamingClassificationWithGenTask::new(request_id, request, headers); + let task = StreamingClassificationWithGenTask::new(trace_id, request, headers); let response_stream = state .orchestrator .handle_streaming_classification_with_gen(task) @@ -404,10 +408,11 @@ async fn detection_content( headers: HeaderMap, Json(request): Json, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = TextContentDetectionTask::new(request_id, request, headers); + let task = TextContentDetectionTask::new(trace_id, request, headers); match state.orchestrator.handle_text_content_detection(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), @@ -420,10 +425,11 @@ async fn detect_context_documents( headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ContextDocsDetectionTask::new(request_id, request, headers); + let task = ContextDocsDetectionTask::new(trace_id, request, headers); match state .orchestrator .handle_context_documents_detection(task) @@ -459,10 +465,11 @@ async fn detect_generated( Error, >, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = DetectionOnGenerationTask::new(request_id, request, headers); + let task = DetectionOnGenerationTask::new(trace_id, request, headers); match state .orchestrator .handle_generated_text_detection(task) diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index b2a18ab1..12ab269d 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -50,12 +50,21 @@ pub enum TracingError { MetricsError(#[from] MetricsError), } +fn service_config(tracing_config: TracingConfig) -> Config { + Config::default() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + tracing_config.service_name, + )])) + .with_sampler(Sampler::AlwaysOn) +} + /// Initializes an OpenTelemetry tracer provider with an OTLP export pipeline based on the /// provided config. fn init_tracer_provider( - otlp_export_config: TracingConfig, + tracing_config: TracingConfig, ) -> Result, TracingError> { - if let Some((protocol, endpoint)) = otlp_export_config.traces { + if let Some((protocol, endpoint)) = tracing_config.clone().traces { Ok(Some( match protocol { OtlpProtocol::Grpc => opentelemetry_otlp::new_pipeline().tracing().with_exporter( @@ -72,16 +81,17 @@ fn init_tracer_provider( .with_timeout(Duration::from_secs(3)), ), } - .with_trace_config( - Config::default() - .with_resource(Resource::new(vec![KeyValue::new( - "service.name", - otlp_export_config.service_name, - )])) - .with_sampler(Sampler::AlwaysOn), - ) + .with_trace_config(service_config(tracing_config)) .install_batch(runtime::Tokio)?, )) + } else if !tracing_config.quiet { + // We still need a tracing provider as long as we are logging in order to enable any + // trace-sensitive logs, such as any mentions of a request's trace_id. + Ok(Some( + opentelemetry_sdk::trace::TracerProvider::builder() + .with_config(service_config(tracing_config)) + .build(), + )) } else { Ok(None) } @@ -90,9 +100,9 @@ fn init_tracer_provider( /// Initializes an OpenTelemetry meter provider with an OTLP export pipeline based on the /// provided config. fn init_meter_provider( - otlp_export_config: TracingConfig, + tracing_config: TracingConfig, ) -> Result, TracingError> { - if let Some((protocol, endpoint)) = otlp_export_config.metrics { + if let Some((protocol, endpoint)) = tracing_config.metrics { Ok(Some( match protocol { OtlpProtocol::Grpc => opentelemetry_otlp::new_pipeline() @@ -113,7 +123,7 @@ fn init_meter_provider( } .with_resource(Resource::new(vec![KeyValue::new( "service.name", - otlp_export_config.service_name, + tracing_config.service_name, )])) .with_timeout(Duration::from_secs(10)) .with_period(Duration::from_secs(3)) @@ -237,11 +247,18 @@ pub fn incoming_request_span(request: &Request) -> Span { pub fn on_incoming_request(request: &Request, span: &Span) { let _guard = span.enter(); + let trace_id = Span::current() + .context() + .span() + .span_context() + .trace_id() + .to_string(); + println!("trace: {}", trace_id); info!( "incoming request to {} {} with trace_id {}", request.method(), request.uri().path(), - span.context().span().span_context().trace_id().to_string() + trace_id, ); info!( monotonic_counter.incoming_request_count = 1, From 6a96c9a2466e638b189d51426f25149ffae15e67 Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Wed, 16 Oct 2024 16:09:18 -0400 Subject: [PATCH 38/50] traceparent header creation Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients.rs | 8 ++++++ src/clients/detector/text_contents.rs | 2 ++ src/clients/detector/text_context_doc.rs | 2 ++ src/clients/detector/text_generation.rs | 2 ++ src/clients/nlp.rs | 25 +++++++++++------- src/clients/openai.rs | 7 ++--- src/clients/tgis.rs | 17 +++++++----- src/tracing_utils.rs | 33 +++++++++++++++++++++++- 8 files changed, 77 insertions(+), 19 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 992da6a8..bc2ab94f 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -24,9 +24,12 @@ use std::{ }; use async_trait::async_trait; +use axum::http::{Extensions, HeaderMap}; use futures::Stream; use ginepro::LoadBalancedChannel; use tokio::{fs::File, io::AsyncReadExt}; +use tonic::metadata::MetadataMap; +use tonic::Request; use tracing::{debug, instrument}; use url::Url; @@ -343,6 +346,11 @@ pub fn is_valid_hostname(hostname: &str) -> bool { || hostname.len() > 253) } +fn grpc_request_with_headers(request: T, headers: HeaderMap) -> Request { + let metadata = MetadataMap::from_headers(headers); + Request::from_parts(metadata, Extensions::new(), request) +} + #[cfg(test)] mod tests { use errors::grpc_to_http_code; diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index d7df440e..a92eaf70 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -26,6 +26,7 @@ use crate::{ config::ServiceConfig, health::HealthCheckResult, models::DetectorParams, + tracing_utils::with_traceparent_header, }; #[cfg_attr(test, faux::create)] @@ -63,6 +64,7 @@ impl TextContentsDetectorClient { .join("/api/v1/text/contents") .unwrap(); info!(?url, ?request, "sending client request"); + let headers = with_traceparent_header(headers); let response = self .client .post(url) diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index e73a985f..c501fd3e 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -26,6 +26,7 @@ use crate::{ config::ServiceConfig, health::HealthCheckResult, models::{DetectionResult, DetectorParams}, + tracing_utils::with_traceparent_header, }; #[cfg_attr(test, faux::create)] @@ -63,6 +64,7 @@ impl TextContextDocDetectorClient { .join("/api/v1/text/context/doc") .unwrap(); info!(?url, ?request, "sending client request"); + let headers = with_traceparent_header(headers); let response = self .client .post(url) diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 2486cea0..874bed54 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -21,6 +21,7 @@ use serde::Serialize; use tracing::{info, instrument}; use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; +use crate::tracing_utils::with_traceparent_header; use crate::{ clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, @@ -63,6 +64,7 @@ impl TextGenerationDetectorClient { .join("/api/v1/text/generation") .unwrap(); info!(?url, ?request, "sending client request"); + let headers = with_traceparent_header(headers); let response = self .client .post(url) diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index bbceb29d..4affb5a3 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -15,7 +15,10 @@ */ -use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; +use super::{ + create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client, + Error, +}; use crate::{ config::ServiceConfig, health::{HealthCheckResult, HealthStatus}, @@ -30,12 +33,13 @@ use crate::{ }, grpc::health::v1::{health_client::HealthClient, HealthCheckRequest}, }, + tracing_utils::with_traceparent_header, }; use async_trait::async_trait; -use axum::http::{Extensions, HeaderMap}; +use axum::http::HeaderMap; use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; -use tonic::{metadata::MetadataMap, Code, Request}; +use tonic::{Code, Request}; use tracing::{info, instrument}; const DEFAULT_PORT: u16 = 8085; @@ -59,7 +63,7 @@ impl NlpClient { } } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn tokenization_task_predict( &self, model_id: &str, @@ -67,6 +71,7 @@ impl NlpClient { headers: HeaderMap, ) -> Result { let mut client = self.client.clone(); + let headers = with_traceparent_header(headers); let request = request_with_model_id(request, model_id, headers); info!(?request, "sending request to NLP gRPC service"); Ok(client @@ -75,7 +80,7 @@ impl NlpClient { .into_inner()) } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn token_classification_task_predict( &self, model_id: &str, @@ -83,6 +88,7 @@ impl NlpClient { headers: HeaderMap, ) -> Result { let mut client = self.client.clone(); + let headers = with_traceparent_header(headers); let request = request_with_model_id(request, model_id, headers); info!(?request, "sending request to NLP gRPC service"); Ok(client @@ -91,7 +97,7 @@ impl NlpClient { .into_inner()) } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn text_generation_task_predict( &self, model_id: &str, @@ -99,6 +105,7 @@ impl NlpClient { headers: HeaderMap, ) -> Result { let mut client = self.client.clone(); + let headers = with_traceparent_header(headers); let request = request_with_model_id(request, model_id, headers); info!(?request, "sending request to NLP gRPC service"); Ok(client @@ -107,7 +114,7 @@ impl NlpClient { .into_inner()) } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn server_streaming_text_generation_task_predict( &self, model_id: &str, @@ -115,6 +122,7 @@ impl NlpClient { headers: HeaderMap, ) -> Result>, Error> { let mut client = self.client.clone(); + let headers = with_traceparent_header(headers); let request = request_with_model_id(request, model_id, headers); info!(?request, "sending stream request to NLP gRPC service"); let response_stream = client @@ -160,8 +168,7 @@ impl Client for NlpClient { } fn request_with_model_id(request: T, model_id: &str, headers: HeaderMap) -> Request { - let metadata = MetadataMap::from_headers(headers); - let mut request = Request::from_parts(metadata, Extensions::new(), request); + let mut request = grpc_request_with_headers(request, headers); request .metadata_mut() .insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 66eea424..bd86d1b7 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -17,12 +17,12 @@ use std::collections::HashMap; +use super::{create_http_client, Client, Error, HttpClient}; +use crate::{config::ServiceConfig, health::HealthCheckResult}; use async_trait::async_trait; use hyper::StatusCode; use serde::{Deserialize, Serialize}; - -use super::{create_http_client, Client, Error, HttpClient}; -use crate::{config::ServiceConfig, health::HealthCheckResult}; +use tracing::instrument; const DEFAULT_PORT: u16 = 8080; @@ -48,6 +48,7 @@ impl OpenAiClient { } } + #[instrument(skip_all, fields(request.model))] pub async fn chat_completions( &self, request: ChatCompletionRequest, diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index cf07493c..065fc5f6 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -22,7 +22,10 @@ use ginepro::LoadBalancedChannel; use tonic::Code; use tracing::{info, instrument}; -use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; +use super::{ + create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client, + Error, +}; use crate::{ config::ServiceConfig, health::{HealthCheckResult, HealthStatus}, @@ -48,23 +51,25 @@ impl TgisClient { Self { client } } - #[instrument(skip_all, fields(model_id = request.model_id, headers = ?_headers))] + #[instrument(skip_all, fields(model_id = request.model_id))] pub async fn generate( &self, request: BatchedGenerationRequest, - _headers: HeaderMap, + headers: HeaderMap, ) -> Result { + let request = grpc_request_with_headers(request, headers); info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); Ok(client.generate(request).await?.into_inner()) } - #[instrument(skip_all, fields(model_id = request.model_id, headers = ?_headers))] + #[instrument(skip_all, fields(model_id = request.model_id))] pub async fn generate_stream( &self, request: SingleGenerationRequest, - _headers: HeaderMap, + headers: HeaderMap, ) -> Result>, Error> { + let request = grpc_request_with_headers(request, headers); info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); let response_stream = client @@ -76,7 +81,7 @@ impl TgisClient { Ok(response_stream) } - #[instrument(skip_all, fields(model_id = request.model_id, headers = ?_headers))] + #[instrument(skip_all, fields(model_id = request.model_id))] pub async fn tokenize( &self, request: BatchedTokenizeRequest, diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index 12ab269d..6bd294dc 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -36,12 +36,16 @@ use opentelemetry_sdk::{ trace::{Config, Sampler}, Resource, }; -use tracing::{error, info, info_span, Span}; +use tracing::{error, info, info_span, warn, Span}; use tracing_opentelemetry::{MetricsLayer, OpenTelemetrySpanExt}; use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer}; use crate::args::{LogFormat, OtlpProtocol, TracingConfig}; +const TRACEPARENT_HEADER_NAME: &str = "traceparent"; +const TRACEPARENT_VERSION: &str = "00"; +const TRACEPARENT_TRACE_FLAGS: &str = "01"; + #[derive(Debug, thiserror::Error)] pub enum TracingError { #[error("Error from tracing provider: {0}")] @@ -337,3 +341,30 @@ pub fn on_outgoing_eos(trailers: Option<&HeaderMap>, stream_duration: Duration, ); info!(monotonic_histogram.service_stream_response_duration = stream_duration.as_millis()); } + +pub fn with_traceparent_header(headers: HeaderMap) -> HeaderMap { + let mut headers = headers.clone(); + if let Some(traceparent) = headers.get(TRACEPARENT_HEADER_NAME) { + warn!( + "traceparent header already set to {}", + traceparent.to_str().unwrap_or_default() // avoiding panics for tracing logic + ) + } + headers.insert( + TRACEPARENT_HEADER_NAME, + get_current_traceparent().parse().unwrap(), + ); + headers +} + +fn get_current_traceparent() -> String { + let ctx = Span::current().context(); + let span_ref = ctx.span(); + let ctx = span_ref.span_context().clone(); + let version = TRACEPARENT_VERSION.to_string(); + let trace_id = ctx.trace_id().to_string(); + let span_id = ctx.span_id().to_string(); + let trace_flags = TRACEPARENT_TRACE_FLAGS; + + version + "-" + &trace_id + "-" + &span_id + "-" + trace_flags +} From f6740d7c5a0f47692685e59a044d45dca76af70c Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Thu, 17 Oct 2024 13:37:12 -0400 Subject: [PATCH 39/50] revised and finished traceparent propagation Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients.rs | 2 + src/clients/chunker.rs | 32 +++++++++------- src/clients/detector.rs | 34 ++++++++++++++-- src/clients/detector/text_contents.rs | 19 +++------ src/clients/detector/text_context_doc.rs | 19 +++------ src/clients/detector/text_generation.rs | 19 +++------ src/clients/generation.rs | 4 +- src/clients/http.rs | 8 ++++ src/clients/nlp.rs | 46 +++++++++------------- src/clients/openai.rs | 18 +++++++-- src/clients/tgis.rs | 23 ++++++----- src/tracing_utils.rs | 49 ++++++++++++------------ 12 files changed, 148 insertions(+), 125 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index bc2ab94f..877b716e 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -36,6 +36,7 @@ use url::Url; use crate::{ config::{ServiceConfig, Tls}, health::HealthCheckResult, + tracing_utils::with_traceparent_header, }; pub mod errors; @@ -347,6 +348,7 @@ pub fn is_valid_hostname(hostname: &str) -> bool { } fn grpc_request_with_headers(request: T, headers: HeaderMap) -> Request { + let headers = with_traceparent_header(headers); let metadata = MetadataMap::from_headers(headers); Request::from_parts(metadata, Extensions::new(), request) } diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 1e877015..9c139b45 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -18,12 +18,16 @@ use std::pin::Pin; use async_trait::async_trait; +use axum::http::HeaderMap; use futures::{Future, Stream, StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::{info, instrument}; -use super::{create_grpc_client, errors::grpc_to_http_code, BoxStream, Client, Error}; +use super::{ + create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client, + Error, +}; use crate::{ config::ServiceConfig, health::{HealthCheckResult, HealthStatus}, @@ -35,6 +39,7 @@ use crate::{ caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults}, grpc::health::v1::{health_client::HealthClient, HealthCheckRequest}, }, + tracing_utils::trace_context_from_grpc_response, }; const DEFAULT_PORT: u16 = 8085; @@ -70,12 +75,11 @@ impl ChunkerClient { request: ChunkerTokenizationTaskRequest, ) -> Result { let mut client = self.client.clone(); - let request = request_with_model_id(request, model_id); + let request = request_with_headers(request, model_id); info!(?request, "sending client request"); - Ok(client - .chunker_tokenization_task_predict(request) - .await? - .into_inner()) + let response = client.chunker_tokenization_task_predict(request).await?; + trace_context_from_grpc_response(&response); + Ok(response.into_inner()) } #[instrument(skip_all, fields(model_id))] @@ -86,17 +90,17 @@ impl ChunkerClient { ) -> Result>, Error> { info!("sending client stream request"); let mut client = self.client.clone(); - let request = request_with_model_id(request_stream, model_id); + let request = request_with_headers(request_stream, model_id); // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors. // https://github.com/rust-lang/rust/issues/110338 let response_stream_fut: Pin + Send>> = Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request)); - let response_stream = response_stream_fut - .await? + let response_stream = response_stream_fut.await?; + trace_context_from_grpc_response(&response_stream); + Ok(response_stream .into_inner() .map_err(Into::into) - .boxed(); - Ok(response_stream) + .boxed()) } } @@ -132,8 +136,8 @@ impl Client for ChunkerClient { } } -fn request_with_model_id(request: T, model_id: &str) -> Request { - let mut request = Request::new(request); +fn request_with_headers(request: T, model_id: &str) -> Request { + let mut request = grpc_request_with_headers(request, HeaderMap::new()); request .metadata_mut() .insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); @@ -141,6 +145,7 @@ fn request_with_model_id(request: T, model_id: &str) -> Request { } /// Unary tokenization result of the entire doc +#[instrument(skip_all)] pub fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults { let codepoint_count = request.text.chars().count() as i64; TokenizationResults { @@ -154,6 +159,7 @@ pub fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> Tokenizati } /// Streaming tokenization result for the entire doc stream +#[instrument(skip_all)] pub async fn tokenize_whole_doc_stream( request: impl Stream, ) -> Result { diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 80e0cb70..225a4ec9 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -15,6 +15,15 @@ */ +use std::fmt::Debug; + +use axum::http::HeaderMap; +use hyper::StatusCode; +use reqwest::Response; +use serde::{Deserialize, Serialize}; +use tracing::info; +use url::Url; + pub mod text_contents; pub use text_contents::*; pub mod text_chat; @@ -22,11 +31,10 @@ pub use text_chat::*; pub mod text_context_doc; pub use text_context_doc::*; pub mod text_generation; -use hyper::StatusCode; -use serde::Deserialize; pub use text_generation::*; -use super::Error; +use super::{Error, HttpClient}; +use crate::tracing_utils::{trace_context_from_http_response, with_traceparent_header}; const DEFAULT_PORT: u16 = 8080; const DETECTOR_ID_HEADER_NAME: &str = "detector-id"; @@ -45,3 +53,23 @@ impl From for Error { } } } + +pub async fn post_with_headers( + client: HttpClient, + url: Url, + request: T, + headers: HeaderMap, + model_id: &str, +) -> Result { + let headers = with_traceparent_header(headers); + info!(?url, ?headers, ?request, "sending client request"); + let response = client + .post(url) + .headers(headers) + .header(DETECTOR_ID_HEADER_NAME, model_id) + .json(&request) + .send() + .await?; + trace_context_from_http_response(&response); + Ok(response) +} diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index a92eaf70..d1015b1d 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -18,15 +18,14 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; -use tracing::{info, instrument}; +use tracing::instrument; -use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; +use super::{post_with_headers, DetectorError, DEFAULT_PORT}; use crate::{ clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, models::DetectorParams, - tracing_utils::with_traceparent_header, }; #[cfg_attr(test, faux::create)] @@ -51,7 +50,7 @@ impl TextContentsDetectorClient { } } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn text_contents( &self, model_id: &str, @@ -63,16 +62,8 @@ impl TextContentsDetectorClient { .base_url() .join("/api/v1/text/contents") .unwrap(); - info!(?url, ?request, "sending client request"); - let headers = with_traceparent_header(headers); - let response = self - .client - .post(url) - .headers(headers) - .header(DETECTOR_ID_HEADER_NAME, model_id) - .json(&request) - .send() - .await?; + let response = + post_with_headers(self.client.clone(), url, request, headers, model_id).await?; if response.status() == StatusCode::OK { Ok(response.json().await?) } else { diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index c501fd3e..e56c1619 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -18,15 +18,14 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; -use tracing::{info, instrument}; +use tracing::instrument; -use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; +use super::{post_with_headers, DetectorError, DEFAULT_PORT}; use crate::{ clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, health::HealthCheckResult, models::{DetectionResult, DetectorParams}, - tracing_utils::with_traceparent_header, }; #[cfg_attr(test, faux::create)] @@ -51,7 +50,7 @@ impl TextContextDocDetectorClient { } } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn text_context_doc( &self, model_id: &str, @@ -63,16 +62,8 @@ impl TextContextDocDetectorClient { .base_url() .join("/api/v1/text/context/doc") .unwrap(); - info!(?url, ?request, "sending client request"); - let headers = with_traceparent_header(headers); - let response = self - .client - .post(url) - .headers(headers) - .header(DETECTOR_ID_HEADER_NAME, model_id) - .json(&request) - .send() - .await?; + let response = + post_with_headers(self.client.clone(), url, request, headers, model_id).await?; if response.status() == StatusCode::OK { Ok(response.json().await?) } else { diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 874bed54..5a63d6c9 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -18,10 +18,9 @@ use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::Serialize; -use tracing::{info, instrument}; +use tracing::instrument; -use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME}; -use crate::tracing_utils::with_traceparent_header; +use super::{post_with_headers, DetectorError, DEFAULT_PORT}; use crate::{ clients::{create_http_client, Client, Error, HttpClient}, config::ServiceConfig, @@ -51,7 +50,7 @@ impl TextGenerationDetectorClient { } } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn text_generation( &self, model_id: &str, @@ -63,16 +62,8 @@ impl TextGenerationDetectorClient { .base_url() .join("/api/v1/text/generation") .unwrap(); - info!(?url, ?request, "sending client request"); - let headers = with_traceparent_header(headers); - let response = self - .client - .post(url) - .headers(headers) - .header(DETECTOR_ID_HEADER_NAME, model_id) - .json(&request) - .send() - .await?; + let response = + post_with_headers(self.client.clone(), url, request, headers, model_id).await?; if response.status() == StatusCode::OK { Ok(response.json().await?) } else { diff --git a/src/clients/generation.rs b/src/clients/generation.rs index 1b514dac..afd54e59 100644 --- a/src/clients/generation.rs +++ b/src/clients/generation.rs @@ -63,7 +63,7 @@ impl GenerationClient { Self(None) } - #[instrument(skip_all, fields(model_id, ?headers))] + #[instrument(skip_all, fields(model_id))] pub async fn tokenize( &self, model_id: String, @@ -103,6 +103,7 @@ impl GenerationClient { } } + #[instrument(skip_all, fields(model_id))] pub async fn generate( &self, model_id: String, @@ -167,6 +168,7 @@ impl GenerationClient { } } + #[instrument(skip_all, fields(model_id))] pub async fn generate_stream( &self, model_id: String, diff --git a/src/clients/http.rs b/src/clients/http.rs index 862db811..803daf8c 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -43,6 +43,14 @@ impl HttpClient { &self.base_url } + pub fn into_inner(self) -> reqwest::Client { + self.client + } + + pub fn inner_as_ref(&self) -> &reqwest::Client { + &self.client + } + /// This is sectioned off to allow for testing. pub(super) async fn http_response_to_health_check_result( res: Result, diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 4affb5a3..67bf6cc8 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -33,7 +33,7 @@ use crate::{ }, grpc::health::v1::{health_client::HealthClient, HealthCheckRequest}, }, - tracing_utils::with_traceparent_header, + tracing_utils::trace_context_from_grpc_response, }; use async_trait::async_trait; use axum::http::HeaderMap; @@ -71,13 +71,11 @@ impl NlpClient { headers: HeaderMap, ) -> Result { let mut client = self.client.clone(); - let headers = with_traceparent_header(headers); - let request = request_with_model_id(request, model_id, headers); + let request = request_with_headers(request, model_id, headers); info!(?request, "sending request to NLP gRPC service"); - Ok(client - .tokenization_task_predict(request) - .await? - .into_inner()) + let response = client.tokenization_task_predict(request).await?; + trace_context_from_grpc_response(&response); + Ok(response.into_inner()) } #[instrument(skip_all, fields(model_id))] @@ -88,13 +86,11 @@ impl NlpClient { headers: HeaderMap, ) -> Result { let mut client = self.client.clone(); - let headers = with_traceparent_header(headers); - let request = request_with_model_id(request, model_id, headers); + let request = request_with_headers(request, model_id, headers); info!(?request, "sending request to NLP gRPC service"); - Ok(client - .token_classification_task_predict(request) - .await? - .into_inner()) + let response = client.token_classification_task_predict(request).await?; + trace_context_from_grpc_response(&response); + Ok(response.into_inner()) } #[instrument(skip_all, fields(model_id))] @@ -105,13 +101,11 @@ impl NlpClient { headers: HeaderMap, ) -> Result { let mut client = self.client.clone(); - let headers = with_traceparent_header(headers); - let request = request_with_model_id(request, model_id, headers); + let request = request_with_headers(request, model_id, headers); info!(?request, "sending request to NLP gRPC service"); - Ok(client - .text_generation_task_predict(request) - .await? - .into_inner()) + let response = client.text_generation_task_predict(request).await?; + trace_context_from_grpc_response(&response); + Ok(response.into_inner()) } #[instrument(skip_all, fields(model_id))] @@ -122,15 +116,13 @@ impl NlpClient { headers: HeaderMap, ) -> Result>, Error> { let mut client = self.client.clone(); - let headers = with_traceparent_header(headers); - let request = request_with_model_id(request, model_id, headers); + let request = request_with_headers(request, model_id, headers); info!(?request, "sending stream request to NLP gRPC service"); - let response_stream = client + let response = client .server_streaming_text_generation_task_predict(request) - .await? - .into_inner() - .map_err(Into::into) - .boxed(); + .await?; + trace_context_from_grpc_response(&response); + let response_stream = response.into_inner().map_err(Into::into).boxed(); Ok(response_stream) } } @@ -167,7 +159,7 @@ impl Client for NlpClient { } } -fn request_with_model_id(request: T, model_id: &str, headers: HeaderMap) -> Request { +fn request_with_headers(request: T, model_id: &str, headers: HeaderMap) -> Request { let mut request = grpc_request_with_headers(request, headers); request .metadata_mut() diff --git a/src/clients/openai.rs b/src/clients/openai.rs index bd86d1b7..d17114f0 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -18,11 +18,13 @@ use std::collections::HashMap; use super::{create_http_client, Client, Error, HttpClient}; -use crate::{config::ServiceConfig, health::HealthCheckResult}; +use crate::{ + config::ServiceConfig, health::HealthCheckResult, tracing_utils::with_traceparent_header, +}; use async_trait::async_trait; -use hyper::StatusCode; +use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; -use tracing::instrument; +use tracing::{info, instrument}; const DEFAULT_PORT: u16 = 8080; @@ -54,7 +56,15 @@ impl OpenAiClient { request: ChatCompletionRequest, ) -> Result { let url = self.client.base_url().join("/v1/chat/completions").unwrap(); - let response = self.client.post(url).json(&request).send().await?; + let headers = with_traceparent_header(HeaderMap::new()); + info!(?url, ?headers, ?request, "sending client request"); + let response = self + .client + .post(url) + .headers(headers) + .json(&request) + .send() + .await?; match response.status() { StatusCode::OK => Ok(response.json().await?), _ => Err(Error::Http { diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index 065fc5f6..a11976b1 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -34,6 +34,7 @@ use crate::{ BatchedGenerationResponse, BatchedTokenizeRequest, BatchedTokenizeResponse, GenerationResponse, ModelInfoRequest, ModelInfoResponse, SingleGenerationRequest, }, + tracing_utils::trace_context_from_grpc_response, }; const DEFAULT_PORT: u16 = 8033; @@ -72,30 +73,32 @@ impl TgisClient { let request = grpc_request_with_headers(request, headers); info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); - let response_stream = client - .generate_stream(request) - .await? - .into_inner() - .map_err(Into::into) - .boxed(); - Ok(response_stream) + let response = client.generate_stream(request).await?; + trace_context_from_grpc_response(&response); + Ok(response.into_inner().map_err(Into::into).boxed()) } #[instrument(skip_all, fields(model_id = request.model_id))] pub async fn tokenize( &self, request: BatchedTokenizeRequest, - _headers: HeaderMap, + headers: HeaderMap, ) -> Result { info!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); - Ok(client.tokenize(request).await?.into_inner()) + let request = grpc_request_with_headers(request, headers); + let response = client.tokenize(request).await?; + trace_context_from_grpc_response(&response); + Ok(response.into_inner()) } pub async fn model_info(&self, request: ModelInfoRequest) -> Result { info!(?request, "sending request to TGIS gRPC service"); + let request = grpc_request_with_headers(request, HeaderMap::new()); let mut client = self.client.clone(); - Ok(client.model_info(request).await?.into_inner()) + let response = client.model_info(request).await?; + trace_context_from_grpc_response(&response); + Ok(response.into_inner()) } } diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index 6bd294dc..b689d716 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -26,7 +26,9 @@ use opentelemetry::{ trace::{TraceContextExt, TraceError, TracerProvider}, KeyValue, }; +use opentelemetry_http::{HeaderExtractor, HeaderInjector}; use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::propagation::TraceContextPropagator; use opentelemetry_sdk::{ metrics::{ reader::{DefaultAggregationSelector, DefaultTemporalitySelector}, @@ -36,16 +38,12 @@ use opentelemetry_sdk::{ trace::{Config, Sampler}, Resource, }; -use tracing::{error, info, info_span, warn, Span}; +use tracing::{error, info, info_span, Span}; use tracing_opentelemetry::{MetricsLayer, OpenTelemetrySpanExt}; use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer}; use crate::args::{LogFormat, OtlpProtocol, TracingConfig}; -const TRACEPARENT_HEADER_NAME: &str = "traceparent"; -const TRACEPARENT_VERSION: &str = "00"; -const TRACEPARENT_TRACE_FLAGS: &str = "01"; - #[derive(Debug, thiserror::Error)] pub enum TracingError { #[error("Error from tracing provider: {0}")] @@ -146,6 +144,7 @@ pub fn init_tracing( tracing_config: TracingConfig, ) -> Result Result<(), TracingError>, TracingError> { let mut layers = Vec::new(); + global::set_text_map_propagator(TraceContextPropagator::new()); // TODO: Find a better way to only propagate errors from other crates let filter = EnvFilter::try_from_default_env() @@ -344,27 +343,27 @@ pub fn on_outgoing_eos(trailers: Option<&HeaderMap>, stream_duration: Duration, pub fn with_traceparent_header(headers: HeaderMap) -> HeaderMap { let mut headers = headers.clone(); - if let Some(traceparent) = headers.get(TRACEPARENT_HEADER_NAME) { - warn!( - "traceparent header already set to {}", - traceparent.to_str().unwrap_or_default() // avoiding panics for tracing logic - ) - } - headers.insert( - TRACEPARENT_HEADER_NAME, - get_current_traceparent().parse().unwrap(), - ); + let ctx = Span::current().context(); + global::get_text_map_propagator(|propagator| { + // Injects current `traceparent` (and by default empty `tracestate`) + propagator.inject_context(&ctx, &mut HeaderInjector(&mut headers)) + }); headers } -fn get_current_traceparent() -> String { - let ctx = Span::current().context(); - let span_ref = ctx.span(); - let ctx = span_ref.span_context().clone(); - let version = TRACEPARENT_VERSION.to_string(); - let trace_id = ctx.trace_id().to_string(); - let span_id = ctx.span_id().to_string(); - let trace_flags = TRACEPARENT_TRACE_FLAGS; - - version + "-" + &trace_id + "-" + &span_id + "-" + trace_flags +pub fn trace_context_from_http_response(response: &reqwest::Response) { + let ctx = global::get_text_map_propagator(|propagator| { + // Returns the current context if no `traceparent` is found + propagator.extract(&HeaderExtractor(response.headers())) + }); + Span::current().set_parent(ctx); +} + +pub fn trace_context_from_grpc_response(response: &tonic::Response) { + let ctx = global::get_text_map_propagator(|propagator| { + let metadata = response.metadata().clone(); + // Returns the current context if no `traceparent` is found + propagator.extract(&HeaderExtractor(&metadata.into_headers())) + }); + Span::current().set_parent(ctx); } From 23c439eeebde37e7c3bebd4e25f54469734e4762 Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Thu, 17 Oct 2024 13:53:13 -0400 Subject: [PATCH 40/50] some clean up of imports Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients.rs | 12 ++++-------- src/clients/http.rs | 8 -------- src/clients/nlp.rs | 13 +++++++------ src/clients/openai.rs | 9 +++++---- 4 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 877b716e..9e7b6bc8 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -28,8 +28,7 @@ use axum::http::{Extensions, HeaderMap}; use futures::Stream; use ginepro::LoadBalancedChannel; use tokio::{fs::File, io::AsyncReadExt}; -use tonic::metadata::MetadataMap; -use tonic::Request; +use tonic::{metadata::MetadataMap, Request}; use tracing::{debug, instrument}; use url::Url; @@ -276,12 +275,9 @@ pub async fn create_grpc_client( .request_timeout .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC), ); - let mut builder = LoadBalancedChannel::builder(( - service_config.hostname.clone(), - service_config.port.unwrap_or(default_port), - )) - .connect_timeout(DEFAULT_CONNECT_TIMEOUT) - .timeout(request_timeout); + let mut builder = LoadBalancedChannel::builder((service_config.hostname.clone(), port)) + .connect_timeout(DEFAULT_CONNECT_TIMEOUT) + .timeout(request_timeout); let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls { let cert_path = tls_config.cert_path.as_ref().unwrap().as_path(); diff --git a/src/clients/http.rs b/src/clients/http.rs index 803daf8c..862db811 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -43,14 +43,6 @@ impl HttpClient { &self.base_url } - pub fn into_inner(self) -> reqwest::Client { - self.client - } - - pub fn inner_as_ref(&self) -> &reqwest::Client { - &self.client - } - /// This is sectioned off to allow for testing. pub(super) async fn http_response_to_health_check_result( res: Result, diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 67bf6cc8..bf934250 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -15,6 +15,13 @@ */ +use async_trait::async_trait; +use axum::http::HeaderMap; +use futures::{StreamExt, TryStreamExt}; +use ginepro::LoadBalancedChannel; +use tonic::{Code, Request}; +use tracing::{info, instrument}; + use super::{ create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client, Error, @@ -35,12 +42,6 @@ use crate::{ }, tracing_utils::trace_context_from_grpc_response, }; -use async_trait::async_trait; -use axum::http::HeaderMap; -use futures::{StreamExt, TryStreamExt}; -use ginepro::LoadBalancedChannel; -use tonic::{Code, Request}; -use tracing::{info, instrument}; const DEFAULT_PORT: u16 = 8085; const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; diff --git a/src/clients/openai.rs b/src/clients/openai.rs index d17114f0..9ea190eb 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -17,15 +17,16 @@ use std::collections::HashMap; -use super::{create_http_client, Client, Error, HttpClient}; -use crate::{ - config::ServiceConfig, health::HealthCheckResult, tracing_utils::with_traceparent_header, -}; use async_trait::async_trait; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; use tracing::{info, instrument}; +use super::{create_http_client, Client, Error, HttpClient}; +use crate::{ + config::ServiceConfig, health::HealthCheckResult, tracing_utils::with_traceparent_header, +}; + const DEFAULT_PORT: u16 = 8080; #[cfg_attr(test, faux::create)] From 1ed58adc19a8ef64876076be847c23880d6caf69 Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Thu, 17 Oct 2024 14:48:07 -0400 Subject: [PATCH 41/50] tracing_utils import nit Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients/chunker.rs | 5 +---- src/tracing_utils.rs | 6 ++---- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 9c139b45..b2fc9acd 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -97,10 +97,7 @@ impl ChunkerClient { Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request)); let response_stream = response_stream_fut.await?; trace_context_from_grpc_response(&response_stream); - Ok(response_stream - .into_inner() - .map_err(Into::into) - .boxed()) + Ok(response_stream.into_inner().map_err(Into::into).boxed()) } } diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index b689d716..77731fc0 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -17,9 +17,7 @@ use std::time::Duration; -use axum::extract::Request; -use axum::http::HeaderMap; -use axum::response::Response; +use axum::{extract::Request, http::HeaderMap, response::Response}; use opentelemetry::{ global, metrics::MetricsError, @@ -28,12 +26,12 @@ use opentelemetry::{ }; use opentelemetry_http::{HeaderExtractor, HeaderInjector}; use opentelemetry_otlp::WithExportConfig; -use opentelemetry_sdk::propagation::TraceContextPropagator; use opentelemetry_sdk::{ metrics::{ reader::{DefaultAggregationSelector, DefaultTemporalitySelector}, SdkMeterProvider, }, + propagation::TraceContextPropagator, runtime, trace::{Config, Sampler}, Resource, From a144c8ec7b36929b8c1ae1780788747bf8cd3383 Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Thu, 17 Oct 2024 16:29:28 -0400 Subject: [PATCH 42/50] remove unneeded tracing_utils change Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/tracing_utils.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index 77731fc0..5a4850aa 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -248,18 +248,11 @@ pub fn incoming_request_span(request: &Request) -> Span { pub fn on_incoming_request(request: &Request, span: &Span) { let _guard = span.enter(); - let trace_id = Span::current() - .context() - .span() - .span_context() - .trace_id() - .to_string(); - println!("trace: {}", trace_id); info!( "incoming request to {} {} with trace_id {}", request.method(), request.uri().path(), - trace_id, + span.context().span().span_context().trace_id().to_string() ); info!( monotonic_counter.incoming_request_count = 1, From a0e85eed89dc47af87586eaef90bb97139f1cc30 Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Mon, 28 Oct 2024 14:38:40 -0400 Subject: [PATCH 43/50] rebase fix and nits Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients/detector/text_chat.rs | 2 +- src/orchestrator.rs | 6 +++--- src/orchestrator/unary.rs | 22 +++++++--------------- src/server.rs | 4 ++-- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index aef3a0c9..bb6fc524 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -60,13 +60,13 @@ impl TextChatDetectorClient { headers: HeaderMap, ) -> Result, Error> { let url = self.client.base_url().join(CHAT_DETECTOR_ENDPOINT).unwrap(); + info!(?url, "sending chat detector client request"); let request = self .client .post(url) .headers(headers) .header(DETECTOR_ID_HEADER_NAME, model_id) .json(&request); - info!(?url, "sending chat detector client request"); debug!("chat detector client request: {:?}", request); let response = request.send().await?; debug!("chat detector client response: {:?}", response); diff --git a/src/orchestrator.rs b/src/orchestrator.rs index dda82b7d..4fbf6ecf 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -387,7 +387,7 @@ impl ContextDocsDetectionTask { #[derive(Debug)] pub struct ChatDetectionTask { /// Request unique identifier - pub request_id: Uuid, + pub trace_id: TraceId, /// Detectors configuration pub detectors: HashMap, @@ -400,9 +400,9 @@ pub struct ChatDetectionTask { } impl ChatDetectionTask { - pub fn new(request_id: Uuid, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self { + pub fn new(trace_id: TraceId, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self { Self { - request_id, + trace_id, detectors: request.detectors, messages: request.messages, headers, diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 9a2d1a4d..66d31d7e 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -449,12 +449,12 @@ impl Orchestrator { } /// Handles detections on chat messages (without performing generation) + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] pub async fn handle_chat_detection( &self, task: ChatDetectionTask, ) -> Result { info!( - request_id = ?task.request_id, detectors = ?task.detectors, "handling detection on chat content task" ); @@ -491,13 +491,13 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "detection task on chat failed"); + error!(%error, "detection task on chat failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "detection task on chat failed"); + error!(%error, "detection task on chat failed"); Err(error) } } @@ -545,7 +545,7 @@ async fn detection_task( chunks: HashMap>, headers: HeaderMap, ) -> Result, Error> { - debug!(detectors = ?detectors.keys(), "handling detection task"); + debug!(detectors = ?detectors.keys(), "handling detection tasks"); // Spawn tasks for each detector let tasks = detectors .iter() @@ -689,11 +689,7 @@ pub async fn detect_content( Vec::default() } else { let request = ContentAnalysisRequest::new(contents, detector_params); - debug!( - ?request, - threshold, - "sending detector request" - ); + debug!(?request, threshold, "sending detector request"); let client = ctx .clients .get_as::(&detector_id) @@ -754,11 +750,7 @@ pub async fn detect_for_generation( ); let request = GenerationDetectionRequest::new(prompt.clone(), generated_text.clone(), detector_params); - debug!( - threshold, - ?request, - "sending generation detector request" - ); + debug!(threshold, ?request, "sending generation detector request"); let client = ctx .clients .get_as::(&detector_id) @@ -912,7 +904,7 @@ pub async fn chunk_parallel( chunker_id: String, text_with_offsets: Vec<(usize, String)>, ) -> Result<(String, Vec), Error> { - debug!("sending parallel chunk request"); + debug!("sending parallel chunk requests"); let chunks = stream::iter(text_with_offsets) .map(|(offset, text)| { let ctx = ctx.clone(); diff --git a/src/server.rs b/src/server.rs index 64210149..8fc0e523 100644 --- a/src/server.rs +++ b/src/server.rs @@ -446,10 +446,10 @@ async fn detect_chat( headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); request.validate_for_text()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ChatDetectionTask::new(request_id, request, headers); + let task = ChatDetectionTask::new(trace_id, request, headers); match state.orchestrator.handle_chat_detection(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), From e8d0b5de059e19a6f33de6ac1c48a073c74a27e1 Mon Sep 17 00:00:00 2001 From: Paul Scoropan <1paulscoropan@gmail.com> Date: Tue, 29 Oct 2024 14:42:43 -0400 Subject: [PATCH 44/50] doc comments Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --- src/clients.rs | 2 ++ src/clients/chunker.rs | 2 ++ src/clients/detector.rs | 2 ++ src/clients/nlp.rs | 2 ++ src/tracing_utils.rs | 13 +++++++++++++ 5 files changed, 21 insertions(+) diff --git a/src/clients.rs b/src/clients.rs index 9e7b6bc8..77dd5aaf 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -343,6 +343,8 @@ pub fn is_valid_hostname(hostname: &str) -> bool { || hostname.len() > 253) } +/// Turns a gRPC client request body of type `T` and header map into a `tonic::Request`. +/// Will also inject the current `traceparent` header into the request based on the current span. fn grpc_request_with_headers(request: T, headers: HeaderMap) -> Request { let headers = with_traceparent_header(headers); let metadata = MetadataMap::from_headers(headers); diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index b2fc9acd..bd924b90 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -133,6 +133,8 @@ impl Client for ChunkerClient { } } +/// Turns a chunker client gRPC request body of type `T` into a `tonic::Request` with headers. +/// Adds the provided `model_id` as a header as well as injects `traceparent` from the current span. fn request_with_headers(request: T, model_id: &str) -> Request { let mut request = grpc_request_with_headers(request, HeaderMap::new()); request diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 225a4ec9..cf06ddef 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -54,6 +54,8 @@ impl From for Error { } } +/// Make a POST request for an HTTP detector client and return the response. +/// Also injects the `traceparent` header from the current span and traces the response. pub async fn post_with_headers( client: HttpClient, url: Url, diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index bf934250..01067019 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -160,6 +160,8 @@ impl Client for NlpClient { } } +/// Turns an NLP client gRPC request body of type `T` and headers into a `tonic::Request`. +/// Also injects provided `model_id` and `traceparent` from current context into headers. fn request_with_headers(request: T, model_id: &str, headers: HeaderMap) -> Request { let mut request = grpc_request_with_headers(request, headers); request diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index 5a4850aa..42da2169 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -332,6 +332,11 @@ pub fn on_outgoing_eos(trailers: Option<&HeaderMap>, stream_duration: Duration, info!(monotonic_histogram.service_stream_response_duration = stream_duration.as_millis()); } +/// Injects the `traceparent` header into the header map from the current tracing span context. +/// Also injects empty `tracestate` header by default. This can be used to propagate +/// vendor-specific trace context. +/// Used by both gRPC and HTTP requests since `tonic::Metadata` uses `http::HeaderMap`. +/// See https://www.w3.org/TR/trace-context/#trace-context-http-headers-format. pub fn with_traceparent_header(headers: HeaderMap) -> HeaderMap { let mut headers = headers.clone(); let ctx = Span::current().context(); @@ -342,6 +347,10 @@ pub fn with_traceparent_header(headers: HeaderMap) -> HeaderMap { headers } +/// Extracts the `traceparent` header from an HTTP response's headers and uses it to set the current +/// tracing span context (i.e. use `traceparent` as parent to the current span). +/// Defaults to using the current context when no `traceparent` is found. +/// See https://www.w3.org/TR/trace-context/#trace-context-http-headers-format. pub fn trace_context_from_http_response(response: &reqwest::Response) { let ctx = global::get_text_map_propagator(|propagator| { // Returns the current context if no `traceparent` is found @@ -350,6 +359,10 @@ pub fn trace_context_from_http_response(response: &reqwest::Response) { Span::current().set_parent(ctx); } +/// Extracts the `traceparent` header from a gRPC response's metadata and uses it to set the current +/// tracing span context (i.e. use `traceparent` as parent to the current span). +/// Defaults to using the current context when no `traceparent` is found. +/// See https://www.w3.org/TR/trace-context/#trace-context-http-headers-format. pub fn trace_context_from_grpc_response(response: &tonic::Response) { let ctx = global::get_text_map_propagator(|propagator| { let metadata = response.metadata().clone(); From 521c80f48c38a142e4a15c336c7eeb499c622f03 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Wed, 30 Oct 2024 13:30:48 -0700 Subject: [PATCH 45/50] Add OpenAiClient stream handling, update types (#230) * Add initial stream handling for OpenAiClient Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Update OpenAiClient types and fix streaming request Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Add OpenAiError and parse error message, order dependencies Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Drop openai client tests module Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Telemetry rebase and updates Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Return SSE events directly from OpenAiClient stream Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Add headers to OpenAiClient chat_completions method Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --------- Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- Cargo.lock | 49 +++++++++++ Cargo.toml | 1 + src/clients/openai.rs | 188 +++++++++++++++++++++++++++++++++++------- src/models.rs | 6 +- 4 files changed, 213 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c558bf93..4bbb6417 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -555,6 +555,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.1.1" @@ -614,6 +625,7 @@ dependencies = [ "opentelemetry_sdk", "prost", "reqwest", + "reqwest-eventsource", "rustls", "rustls-pemfile", "rustls-webpki", @@ -743,6 +755,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -1846,15 +1864,33 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "windows-registry", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "reserve-port" version = "2.0.1" @@ -2834,6 +2870,19 @@ version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +[[package]] +name = "wasm-streams" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.70" diff --git a/Cargo.toml b/Cargo.toml index 2d566564..310c2c62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ opentelemetry-otlp = { version = "0.17.0", features = ["http-proto"] } opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio", "metrics"] } prost = "0.13.1" reqwest = { version = "0.12.5", features = ["blocking", "rustls-tls", "json"] } +reqwest-eventsource = "0.6.0" rustls = {version = "0.23.12", default-features = false, features = ["std"]} rustls-pemfile = "2.1.3" rustls-webpki = "0.102.6" diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 9ea190eb..7abdac21 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -15,11 +15,15 @@ */ -use std::collections::HashMap; +use std::{collections::HashMap, convert::Infallible}; use async_trait::async_trait; +use axum::response::sse; +use futures::StreamExt; use hyper::{HeaderMap, StatusCode}; +use reqwest_eventsource::{Event, RequestBuilderExt}; use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; use tracing::{info, instrument}; use super::{create_http_client, Client, Error, HttpClient}; @@ -55,23 +59,62 @@ impl OpenAiClient { pub async fn chat_completions( &self, request: ChatCompletionRequest, + headers: HeaderMap, ) -> Result { let url = self.client.base_url().join("/v1/chat/completions").unwrap(); - let headers = with_traceparent_header(HeaderMap::new()); + let headers = with_traceparent_header(headers); + let stream = request.stream.unwrap_or_default(); info!(?url, ?headers, ?request, "sending client request"); - let response = self - .client - .post(url) - .headers(headers) - .json(&request) - .send() - .await?; - match response.status() { - StatusCode::OK => Ok(response.json().await?), - _ => Err(Error::Http { - code: response.status(), - message: "".into(), // TODO - }), + if stream { + let (tx, rx) = mpsc::channel(32); + let mut event_stream = self + .client + .post(url) + .headers(headers) + .json(&request) + .eventsource() + .unwrap(); + // Spawn task to forward events to receiver + tokio::spawn(async move { + while let Some(result) = event_stream.next().await { + match result { + Ok(event) => { + if let Event::Message(message) = event { + let event = sse::Event::default().data(message.data); + let _ = tx.send(Ok(event)).await; + } + } + Err(reqwest_eventsource::Error::StreamEnded) => break, + Err(error) => { + // We received an error from the event stream, send an error event + let event = + sse::Event::default().event("error").data(error.to_string()); + let _ = tx.send(Ok(event)).await; + } + } + } + }); + Ok(ChatCompletionResponse::Streaming(rx)) + } else { + let response = self + .client + .post(url) + .headers(headers) + .json(&request) + .send() + .await?; + match response.status() { + StatusCode::OK => Ok(response.json::().await?.into()), + _ => { + let code = response.status(); + let message = if let Ok(response) = response.json::().await { + response.message + } else { + "unknown error occurred".into() + }; + Err(Error::Http { code, message }) + } + } } } } @@ -92,7 +135,19 @@ impl Client for OpenAiClient { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug)] +pub enum ChatCompletionResponse { + Unary(ChatCompletion), + Streaming(mpsc::Receiver>), +} + +impl From for ChatCompletionResponse { + fn from(value: ChatCompletion) -> Self { + Self::Unary(value) + } +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ChatCompletionRequest { /// A list of messages comprising the conversation so far. pub messages: Vec, @@ -290,7 +345,7 @@ pub struct JsonSchemaObject { pub required: Option>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct Message { /// The role of the messages author. pub role: String, @@ -315,16 +370,60 @@ pub struct Message { #[serde(untagged)] pub enum Content { /// The text contents of the message. - String(String), + Text(String), /// Array of content parts. Array(Vec), } -#[derive(Debug, Clone, Serialize, Deserialize)] +impl From for Content { + fn from(value: String) -> Self { + Content::Text(value) + } +} + +impl From<&str> for Content { + fn from(value: &str) -> Self { + Content::Text(value.to_string()) + } +} + +impl From> for Content { + fn from(value: Vec) -> Self { + Content::Array(value) + } +} + +impl From for ContentPart { + fn from(value: String) -> Self { + ContentPart { + r#type: ContentType::Text, + text: Some(value), + image_url: None, + refusal: None, + } + } +} + +impl From> for Content { + fn from(value: Vec) -> Self { + Content::Array(value.into_iter().map(|v| v.into()).collect()) + } +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub enum ContentType { + #[serde(rename = "text")] + #[default] + Text, + #[serde(rename = "image_url")] + ImageUrl, +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ContentPart { /// The type of the content part. #[serde(rename = "type")] - pub r#type: String, + pub r#type: ContentType, /// Text content #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, @@ -367,7 +466,7 @@ pub struct Function { /// Represents a chat completion response returned by model, based on the provided input. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatCompletionResponse { +pub struct ChatCompletion { /// A unique identifier for the chat completion. pub id: String, /// A list of chat completion choices. Can be more than one if n is greater than 1. @@ -378,10 +477,11 @@ pub struct ChatCompletionResponse { pub model: String, /// The service tier used for processing the request. /// This field is only included if the `service_tier` parameter is specified in the request. + #[serde(skip_serializing_if = "Option::is_none")] pub service_tier: Option, /// This fingerprint represents the backend configuration that the model runs with. - #[serde(default)] - pub system_fingerprint: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, /// The object type, which is always `chat.completion`. pub object: String, /// Usage statistics for the completion request. @@ -407,8 +507,8 @@ pub struct ChatCompletionMessage { /// The contents of the message. pub content: Option, /// The refusal message generated by the model. + #[serde(skip_serializing_if = "Option::is_none")] pub refusal: Option, - #[serde(default)] pub tool_calls: Vec, /// The role of the author of this message. pub role: String, @@ -429,8 +529,10 @@ pub struct ChatCompletionLogprob { pub token: String, /// The log probability of this token. pub logprob: f32, + #[serde(skip_serializing_if = "Option::is_none")] pub bytes: Option>, /// List of the most likely tokens and their log probability, at this token position. + #[serde(skip_serializing_if = "Option::is_none")] pub top_logprobs: Option>, } @@ -455,9 +557,11 @@ pub struct ChatCompletionChunk { pub model: String, /// The service tier used for processing the request. /// This field is only included if the service_tier parameter is specified in the request. + #[serde(skip_serializing_if = "Option::is_none")] pub service_tier: Option, /// This fingerprint represents the backend configuration that the model runs with. - pub system_fingerprint: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, /// The object type, which is always `chat.completion.chunk`. pub object: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -467,7 +571,7 @@ pub struct ChatCompletionChunk { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionChunkChoice { /// A chat completion delta generated by streamed model responses. - pub delta: ChatCompletionMessage, + pub delta: ChatCompletionDelta, /// Log probability information for the choice. pub logprobs: Option, /// The reason the model stopped generating tokens. @@ -476,6 +580,22 @@ pub struct ChatCompletionChunkChoice { pub index: u32, } +/// A chat completion delta generated by streamed model responses. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionDelta { + /// The contents of the message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// The refusal message generated by the model. + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub tool_calls: Vec, + /// The role of the author of this message. + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, +} + /// Usage statistics for a completion. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Usage { @@ -486,9 +606,11 @@ pub struct Usage { /// Total number of tokens used in the request (prompt + completion). pub total_tokens: u32, /// Breakdown of tokens used in a completion. - pub completion_token_details: CompletionTokenDetails, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_token_details: Option, /// Breakdown of tokens used in the prompt. - pub prompt_token_details: PromptTokenDetails, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_token_details: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -509,3 +631,13 @@ pub enum StopTokens { Array(Vec), String(String), } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenAiError { + pub object: Option, + pub message: String, + #[serde(rename = "type")] + pub r#type: Option, + pub param: Option, + pub code: u16, +} diff --git a/src/models.rs b/src/models.rs index dee1ffc5..68b790a0 100644 --- a/src/models.rs +++ b/src/models.rs @@ -25,7 +25,7 @@ use crate::{ clients::{ self, detector::{ContentAnalysisResponse, ContextType}, - openai::Content, + openai::{Content, ContentType}, }, health::HealthCheckCache, pb, @@ -998,7 +998,7 @@ impl ChatDetectionHttpRequest { match content { Content::Array(content) => { for content_part in content { - if content_part.r#type != "text" { + if !matches!(content_part.r#type, ContentType::Text) { return Err(ValidationError::Invalid( "Only content of type text is allowed".into(), )); @@ -1006,7 +1006,7 @@ impl ChatDetectionHttpRequest { } Ok(()) } - Content::String(_) => Ok(()), // if message.content is a string, it is a valid message + Content::Text(_) => Ok(()), // if message.content is a string, it is a valid message } } } From f2010e1d8b621e0202868bbd6bbf81deb747b0d3 Mon Sep 17 00:00:00 2001 From: Chris Santiago Date: Mon, 4 Nov 2024 15:43:53 -0600 Subject: [PATCH 46/50] added error messages to various unwraps (#243) * added error messages to various unwraps Signed-off-by: resoluteCoder * added detector id vars to expect Signed-off-by: resoluteCoder * changed expects to unwrap or else due to lint Signed-off-by: resoluteCoder --------- Signed-off-by: resoluteCoder --- src/clients.rs | 7 +++++-- src/orchestrator/streaming.rs | 13 ++++++++++--- src/orchestrator/streaming/aggregator.rs | 4 +++- src/orchestrator/unary.rs | 24 ++++++++++++++++++++---- src/server.rs | 14 +++++++++++--- 5 files changed, 49 insertions(+), 13 deletions(-) diff --git a/src/clients.rs b/src/clients.rs index 77dd5aaf..4aa355b0 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -204,8 +204,11 @@ pub async fn create_http_client(default_port: u16, service_config: &ServiceConfi Some(_) => "https", None => "http", }; - let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap(); - base_url.set_port(Some(port)).unwrap(); + let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)) + .unwrap_or_else(|e| panic!("error parsing base url: {}", e)); + base_url + .set_port(Some(port)) + .unwrap_or_else(|_| panic!("error setting port: {}", port)); debug!(%base_url, "creating HTTP client"); let request_timeout = Duration::from_secs( service_config diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index be71648e..861ab354 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -271,11 +271,18 @@ async fn streaming_output_detection_task( // Create a mutable copy of the parameters, so that we can modify it based on processing let mut detector_params = detector_params.clone(); let detector_id = detector_id.to_string(); - let chunker_id = ctx.config.get_chunker_id(&detector_id).unwrap(); + let chunker_id = ctx + .config + .get_chunker_id(&detector_id) + .expect("chunker id is not found"); // Get the detector config // TODO: Add error handling - let detector_config = ctx.config.detectors.get(&detector_id).unwrap(); + let detector_config = ctx + .config + .detectors + .get(&detector_id) + .expect("detector config not found"); // Get the default threshold to use if threshold is not provided by the user let default_threshold = detector_config.default_threshold; @@ -394,7 +401,7 @@ async fn detection_task( let client = ctx .clients .get_as::(&detector_id) - .unwrap(); + .unwrap_or_else(|| panic!("text contents detector client not found for {}", detector_id)); match client.text_contents(&detector_id, request, headers) .await .map_err(|error| Error::DetectorRequestFailed { id: detector_id.clone(), error }) { diff --git a/src/orchestrator/streaming/aggregator.rs b/src/orchestrator/streaming/aggregator.rs index 97c83d64..c1f6fc72 100644 --- a/src/orchestrator/streaming/aggregator.rs +++ b/src/orchestrator/streaming/aggregator.rs @@ -142,7 +142,9 @@ impl ResultActor { result.token_classification_results.output = Some(detections); if input_start_index == 0 { // Get input_token_count and seed from first generation message - let first = generations.first().unwrap(); + let first = generations + .first() + .expect("first element in classified generated text stream result not found"); result.input_token_count = first.input_token_count; result.seed = first.seed; // Get input_tokens from second generation message (if specified) diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs index 66d31d7e..bb049447 100644 --- a/src/orchestrator/unary.rs +++ b/src/orchestrator/unary.rs @@ -270,13 +270,19 @@ impl Orchestrator { let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - let detector_config = ctx.config.detectors.get(&detector_id).unwrap(); + let detector_config = + ctx.config.detectors.get(&detector_id).unwrap_or_else(|| { + panic!("detector config not found for {}", detector_id) + }); let chunker_id = detector_config.chunker_id.as_str(); let default_threshold = detector_config.default_threshold; - let chunk = chunks.get(chunker_id).unwrap().clone(); + let chunk = chunks + .get(chunker_id) + .unwrap_or_else(|| panic!("chunk not found for {}", chunker_id)) + .clone(); let headers = headers.clone(); @@ -754,7 +760,12 @@ pub async fn detect_for_generation( let client = ctx .clients .get_as::(&detector_id) - .unwrap(); + .unwrap_or_else(|| { + panic!( + "text generation detector client not found for {}", + detector_id + ) + }); let response = client .text_generation(&detector_id, request, headers) .await @@ -845,7 +856,12 @@ pub async fn detect_for_context( let client = ctx .clients .get_as::(&detector_id) - .unwrap(); + .unwrap_or_else(|| { + panic!( + "text context doc detector client not found for {}", + detector_id + ) + }); let response = client .text_context_doc(&detector_id, request, headers) .await diff --git a/src/server.rs b/src/server.rs index 8fc0e523..7907dc86 100644 --- a/src/server.rs +++ b/src/server.rs @@ -129,15 +129,23 @@ pub async fn run( // Configure mTLS if client CA is provided let client_auth = if tls_client_ca_cert_path.is_some() { info!("Configuring TLS trust certificate (mTLS) for incoming connections"); - let client_certs = load_certs(tls_client_ca_cert_path.as_ref().unwrap()); + let client_certs = load_certs( + tls_client_ca_cert_path + .as_ref() + .expect("error loading certs for mTLS"), + ); let mut client_auth_certs = RootCertStore::empty(); for client_cert in client_certs { // Should be only one - client_auth_certs.add(client_cert).unwrap(); + client_auth_certs + .add(client_cert.clone()) + .unwrap_or_else(|e| { + panic!("error adding client cert {:?}: {}", client_cert, e) + }); } WebPkiClientVerifier::builder(client_auth_certs.into()) .build() - .unwrap() + .unwrap_or_else(|e| panic!("error building client verifier: {}", e)) } else { WebPkiClientVerifier::no_client_auth() }; From c15b2a2619f4ec9b9f5b4d8352e045e28457c687 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:53:18 -0800 Subject: [PATCH 47/50] Implement Chat Completions API (#240) * Implement Chat Completions API Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Conditionally enable chat completions endpoint Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Update chat completions to chat completions detection, rename items for alignment Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --------- Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- src/clients/openai.rs | 12 ++-- src/orchestrator.rs | 23 ++++++- .../chat_completions_detection.rs | 20 ++++++ src/server.rs | 62 +++++++++++++++---- 4 files changed, 98 insertions(+), 19 deletions(-) create mode 100644 src/orchestrator/chat_completions_detection.rs diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 7abdac21..a626d111 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -58,9 +58,9 @@ impl OpenAiClient { #[instrument(skip_all, fields(request.model))] pub async fn chat_completions( &self, - request: ChatCompletionRequest, + request: ChatCompletionsRequest, headers: HeaderMap, - ) -> Result { + ) -> Result { let url = self.client.base_url().join("/v1/chat/completions").unwrap(); let headers = with_traceparent_header(headers); let stream = request.stream.unwrap_or_default(); @@ -94,7 +94,7 @@ impl OpenAiClient { } } }); - Ok(ChatCompletionResponse::Streaming(rx)) + Ok(ChatCompletionsResponse::Streaming(rx)) } else { let response = self .client @@ -136,19 +136,19 @@ impl Client for OpenAiClient { } #[derive(Debug)] -pub enum ChatCompletionResponse { +pub enum ChatCompletionsResponse { Unary(ChatCompletion), Streaming(mpsc::Receiver>), } -impl From for ChatCompletionResponse { +impl From for ChatCompletionsResponse { fn from(value: ChatCompletion) -> Self { Self::Unary(value) } } #[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub struct ChatCompletionRequest { +pub struct ChatCompletionsRequest { /// A list of messages comprising the conversation so far. pub messages: Vec, /// ID of the model to use. diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 4fbf6ecf..a7fa465d 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -17,6 +17,7 @@ pub mod errors; pub use errors::Error; +pub mod chat_completions_detection; pub mod streaming; pub mod unary; @@ -35,7 +36,7 @@ use crate::{ text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, }, - openai::OpenAiClient, + openai::{ChatCompletionsRequest, OpenAiClient}, ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, }, config::{DetectorType, GenerationProvider, OrchestratorConfig}, @@ -469,6 +470,26 @@ impl StreamingClassificationWithGenTask { } } +#[derive(Debug)] +pub struct ChatCompletionsDetectionTask { + /// Unique identifier of request trace + pub trace_id: TraceId, + /// Chat completion request + pub request: ChatCompletionsRequest, + // Headermap + pub headers: HeaderMap, +} + +impl ChatCompletionsDetectionTask { + pub fn new(trace_id: TraceId, request: ChatCompletionsRequest, headers: HeaderMap) -> Self { + Self { + trace_id, + request, + headers, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/orchestrator/chat_completions_detection.rs b/src/orchestrator/chat_completions_detection.rs new file mode 100644 index 00000000..dd06fab0 --- /dev/null +++ b/src/orchestrator/chat_completions_detection.rs @@ -0,0 +1,20 @@ +use tracing::{info, instrument}; + +use super::{ChatCompletionsDetectionTask, Error, Orchestrator}; +use crate::clients::openai::{ChatCompletionsResponse, OpenAiClient}; + +impl Orchestrator { + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] + pub async fn handle_chat_completions_detection( + &self, + task: ChatCompletionsDetectionTask, + ) -> Result { + info!("handling chat completions detection task"); + let client = self + .ctx + .clients + .get_as::("chat_generation") + .expect("chat_generation client not found"); + Ok(client.chat_completions(task.request, task.headers).await?) + } +} diff --git a/src/server.rs b/src/server.rs index 7907dc86..e7bd12d6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -44,6 +44,7 @@ use opentelemetry::trace::TraceContextExt; use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}; use tokio::{net::TcpListener, signal}; use tokio_rustls::TlsAcceptor; +use tokio_stream::wrappers::ReceiverStream; use tower_http::trace::TraceLayer; use tower_service::Service; use tracing::{debug, error, info, instrument, warn, Span}; @@ -51,11 +52,12 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; use webpki::types::{CertificateDer, PrivateKeyDer}; use crate::{ + clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse}, models::{self, InfoParams, InfoResponse}, orchestrator::{ - self, ChatDetectionTask, ClassificationWithGenTask, ContextDocsDetectionTask, - DetectionOnGenerationTask, GenerationWithDetectionTask, Orchestrator, - StreamingClassificationWithGenTask, TextContentDetectionTask, + self, ChatCompletionsDetectionTask, ChatDetectionTask, ClassificationWithGenTask, + ContextDocsDetectionTask, DetectionOnGenerationTask, GenerationWithDetectionTask, + Orchestrator, StreamingClassificationWithGenTask, TextContentDetectionTask, }, tracing_utils, }; @@ -160,7 +162,7 @@ pub async fn run( } // (2b) Add main guardrails server routes - let app = Router::new() + let mut router = Router::new() .route( &format!("{}/classification-with-text-generation", API_PREFIX), post(classification_with_gen), @@ -191,16 +193,25 @@ pub async fn run( .route( &format!("{}/detection/generated", TEXT_API_PREFIX), post(detect_generated), - ) - .with_state(shared_state) - .layer( - TraceLayer::new_for_http() - .make_span_with(tracing_utils::incoming_request_span) - .on_request(tracing_utils::on_incoming_request) - .on_response(tracing_utils::on_outgoing_response) - .on_eos(tracing_utils::on_outgoing_eos), ); + // If chat generation is configured, enable the chat completions detection endpoint. + if shared_state.orchestrator.config().chat_generation.is_some() { + info!("Enabling chat completions detection endpoint"); + router = router.route( + "/api/v2/chat/completions-detection", + post(chat_completions_detection), + ); + } + + let app = router.with_state(shared_state).layer( + TraceLayer::new_for_http() + .make_span_with(tracing_utils::incoming_request_span) + .on_request(tracing_utils::on_incoming_request) + .on_response(tracing_utils::on_outgoing_response) + .on_eos(tracing_utils::on_outgoing_eos), + ); + // (2c) Generate main guardrails server handle based on whether TLS is needed let listener: TcpListener = TcpListener::bind(&http_addr) .await @@ -488,6 +499,33 @@ async fn detect_generated( } } +#[instrument(skip_all)] +async fn chat_completions_detection( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = ChatCompletionsDetectionTask::new(trace_id, request, headers); + match state + .orchestrator + .handle_chat_completions_detection(task) + .await + { + Ok(response) => match response { + ChatCompletionsResponse::Unary(response) => Ok(Json(response).into_response()), + ChatCompletionsResponse::Streaming(response_rx) => { + let response_stream = ReceiverStream::new(response_rx); + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); + Ok(sse.into_response()) + } + }, + Err(error) => Err(error.into()), + } +} + /// Shutdown signal handler async fn shutdown_signal() { let ctrl_c = async { From b47c0c9df219850816f5f880849560776820f7e8 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 7 Nov 2024 09:41:19 -0700 Subject: [PATCH 48/50] :memo: Streaming content API (#246) * :memo: Document tags for detectors API Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :construction: Start stream content API and ADR Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :memo: Update content stream response Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :memo: Document content for request event Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :memo: Update stream content API and decisions Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :truck: Rename field Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :memo: Update types and clarify indices Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- docs/api/openapi_detector_api.yaml | 11 +++ docs/api/orchestrator_openapi_0_1_0.yaml | 98 ++++++++++++++++--- .../adrs/005-chat-completion-support.md | 2 +- .../adrs/007-orchestrator-telemetry.md | 2 +- .../008-streaming-orchestrator-endpoints.md | 42 ++++++++ 5 files changed, 141 insertions(+), 14 deletions(-) create mode 100644 docs/architecture/adrs/008-streaming-orchestrator-endpoints.md diff --git a/docs/api/openapi_detector_api.yaml b/docs/api/openapi_detector_api.yaml index 9bd96f0d..fb64ea23 100644 --- a/docs/api/openapi_detector_api.yaml +++ b/docs/api/openapi_detector_api.yaml @@ -5,9 +5,14 @@ info: name: Apache 2.0 url: https://www.apache.org/licenses/LICENSE-2.0.html version: 0.0.1 +tags: + - name: Text + description: Detections on text paths: /api/v1/text/contents: post: + tags: + - Text summary: Text Content Analysis Unary Handler description: >- Detectors that work on content text, be it user prompt or generated @@ -67,6 +72,8 @@ paths: $ref: '#/components/schemas/Error' /api/v1/text/generation: post: + tags: + - Text summary: Generation Analysis Unary Handler description: >- Detectors that run on prompt and text generation output.
@@ -115,6 +122,8 @@ paths: $ref: '#/components/schemas/Error' /api/v1/text/chat: post: + tags: + - Text summary: Chat Analysis Unary Handler description: >- Detectors that analyze chat messages and provide detections
@@ -162,6 +171,8 @@ paths: $ref: '#/components/schemas/Error' /api/v1/text/context/doc: post: + tags: + - Text summary: Context Analysis Unary Handler description: >- Detectors that work on a context created by document(s).
diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml index 9b375f74..6b215405 100644 --- a/docs/api/orchestrator_openapi_0_1_0.yaml +++ b/docs/api/orchestrator_openapi_0_1_0.yaml @@ -103,19 +103,19 @@ paths: "200": description: Successful Response content: - application/json: + text/event-stream: schema: $ref: "#/components/schemas/ClassifiedGeneratedTextStreamResult" "404": description: Resource Not Found content: - application/json: + text/event-stream: schema: $ref: "#/components/schemas/Error" "422": description: Validation Error content: - application/json: + text/event-stream: schema: $ref: "#/components/schemas/Error" /api/v2/text/generation-detection: @@ -183,6 +183,56 @@ paths: application/json: schema: $ref: "#/components/schemas/Error" + /api/v2/text/detection/stream-content: + post: + tags: + - Task - Detection + summary: Detection task on input content stream + operationId: >- + api_v2_detection_text_content_bidi_stream_handler + requestBody: + content: + application/x-ndjson: + schema: + oneOf: + - $ref: "#/components/schemas/DetectionContentRequest" + - $ref: "#/components/schemas/DetectionContentStreamEvent" + # In OpenAPI 3.0, examples cannot be present in schemas, + # whereas object level examples are present in 3.1 + examples: + first_event: + summary: First text event with detectors + value: + detectors: + hap-v1-model-en: {} + content: "my text here" + text: + summary: Regular text event + value: + content: "my text here" + required: true + responses: + "200": + description: Successful Response + content: + text/event-stream: + schema: + # NOTE: This endpoint, like the + # `server-streaming-classification-with-text-generation` + # endpoint will produce streamed events + $ref: "#/components/schemas/DetectionContentStreamEvent" + "404": + description: Resource Not Found + content: + text/event-stream: + schema: + $ref: "#/components/schemas/Error" + "422": + description: Validation Error + content: + text/event-stream: + schema: + $ref: "#/components/schemas/Error" /api/v2/text/detection/chat: post: tags: @@ -361,6 +411,7 @@ components: content: type: string title: Content + example: "my text here" required: ["detectors", "content"] additionalProperties: false type: object @@ -397,15 +448,38 @@ components: title: Score title: Content Detection Response Object example: - - { - "start": 0, - "end": 20, - "text": "string", - "detection_type": "HAP", - "detection": "has_HAP", - "detector_id": "hap-v1-model-en", - "score": 0.999, - } + start: 0 + end: 20 + text: "string" + detection_type: "HAP" + detection: "has_HAP" + detector_id: "hap-v1-model-en" + score: 0.999 + DetectionContentStreamEvent: + properties: + content: + type: string + title: Content + example: "my text here" + required: ["content"] + type: object + description: Individual stream event + title: Content Detection Stream Event + DetectionContentStreamResponse: + properties: + detections: + type: array + items: + $ref: "#/components/schemas/DetectionContentResponseObject" + processed_index: + anyOf: + - type: integer + title: Processed Index + start_index: + type: integer + title: Start Index + type: object + title: Content Detection Stream Response DetectionChatRequest: properties: diff --git a/docs/architecture/adrs/005-chat-completion-support.md b/docs/architecture/adrs/005-chat-completion-support.md index d04e92da..6fb5156a 100644 --- a/docs/architecture/adrs/005-chat-completion-support.md +++ b/docs/architecture/adrs/005-chat-completion-support.md @@ -73,4 +73,4 @@ This means that the orchestrator will have to be able to track chunking and dete ## Status -Proposed +Accepted diff --git a/docs/architecture/adrs/007-orchestrator-telemetry.md b/docs/architecture/adrs/007-orchestrator-telemetry.md index 00690512..19faffac 100644 --- a/docs/architecture/adrs/007-orchestrator-telemetry.md +++ b/docs/architecture/adrs/007-orchestrator-telemetry.md @@ -71,7 +71,7 @@ responses back to the end user (or expect `traceparent` in incoming requests) fo ## Status -Proposed +Accepted ## Consequences diff --git a/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md b/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md new file mode 100644 index 00000000..418c25ac --- /dev/null +++ b/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md @@ -0,0 +1,42 @@ +# ADR 008: Streaming orchestrator endpoints + +This ADR documents the patterns and behavior expected for streaming orchestrator endpoints. + +The orchestrator API can be found [at these github pages](https://foundation-model-stack.github.io/fms-guardrails-orchestrator/). + +## Motivation + +In [ADR 004](./004-orchestrator-input-only-api-design.md), the design of "input only" detection endpoints was detailed. Currently, those endpoints could only support the "unary" case, where the entire input text is available upfront. For flexibility (example: text is streamed from a generative model that may be available but uncallable through the endpoints with generation), users may still want to call detections on streamed input text. + +The orchestrator will then need to support "bidirectional streaming" endpoints, where text (whether tokens, words, sentences) is streamed in, detectors are invoked (and call their respective chunkers, using bidirectional streaming), and text processed with detectors including potential detections is streamed back to the user. + + +## Decisions + +### Server streaming or endpoint output streaming +"Server streaming" endpoints existed already prior to the writing of this particular ADR. Streaming response aggregation behavior is documented in [ADR 002](./002-streaming-response-aggregation.md). Data will continue to be streamed back with `data` events, with errors included as `event: error` per the [SSE event format](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format). + +Parameters in each response event such as `start_index` and `processed_index` will indicate to the user how much of the input stream has been processed for detections, as there might not necessarily be results like positive `detections` for certain portions of the input stream. The `start_index` and `processed_index` will be relative to the entire stream. + +### Client streaming or endpoint input streaming +- Any information needed for an entire request, like `detectors` that any detection endpoints will work on, will be expected to be present in the first event of a stream. The structure of stream events expected will be documented for each endpoint. + - An alternate consideration was using query or path parameters for information needed for an entire request, like `detectors`, but this would be complicated for the nesting that `detectors` require currently, with a mapping of each detector to dictionary parameters. + - Another alternate consideration was expecting multipart requests, one part with information for the entire request like `detectors` and another part with individual stream events. However, here the content type accepted by the request would have to change. +- Stream closing will be the expected indication that stream events have ended. + - An alternate consideration is an explicit "end of stream" request message for each endpoint, for the user to indicate the connection should be closed. For example for the OpenAI chat completions API, this looks like a `[DONE]` event. The downside here is that this particular event's contents will have to be identified and processed differently from other events. + +### Separate streaming detection endpoints + +To be clear to users, we will start with endpoints that indicate `stream` in the endpoint name. We want to avoid adding `stream` parameters in the request body since this will increase the maintenance of parameters on each request event in the streaming case. Additionally, as detailed earlier, stream endpoint responses will tend to have additional or potentially different fields than their unary counterparts. This point can be altered based on sufficient user feedback. + +NOTE: This ADR will not prescribe implementation details, but while the underlying implementation _could_ use [websockets](https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API), we are explicitly not following the patterns of some websocket APIs that require connecting and disconnecting. + +## Consequences + +- Stream detection endpoints will be separate from current "unary" ones that take entire inputs and return one response. Users then must change endpoints for this different use case. +- The orchestrator can support input or client streaming in a consistent manner. This will enable orchestrator users that may want to stream input content from other sources, like their own generative model. +- Users have to be aware that for input streaming, the first event may need to contain more information necessary for the endpoint. Thus the event message structure may not be exactly the same across events in the stream. + +## Status + +Proposed From 774cdaa5808e77f23657c36e49952d5283c0ecf2 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Mon, 11 Nov 2024 09:57:27 -0800 Subject: [PATCH 49/50] :bug: Update stream content API with intended response (#251) Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- docs/api/orchestrator_openapi_0_1_0.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml index 6b215405..cd27db1d 100644 --- a/docs/api/orchestrator_openapi_0_1_0.yaml +++ b/docs/api/orchestrator_openapi_0_1_0.yaml @@ -220,7 +220,7 @@ paths: # NOTE: This endpoint, like the # `server-streaming-classification-with-text-generation` # endpoint will produce streamed events - $ref: "#/components/schemas/DetectionContentStreamEvent" + $ref: "#/components/schemas/DetectionContentStreamResponse" "404": description: Resource Not Found content: From f5a877a39d2e49a5b8b42327ecc25c0266130590 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Tue, 19 Nov 2024 16:23:42 -0500 Subject: [PATCH 50/50] Add client for Caikit-NLP on http protocol --- src/clients.rs | 3 + src/clients/generation.rs | 107 +++++++++++++++++++++- src/clients/nlp_http.rs | 184 ++++++++++++++++++++++++++++++++++++++ src/config.rs | 2 + src/orchestrator.rs | 7 +- 5 files changed, 300 insertions(+), 3 deletions(-) create mode 100644 src/clients/nlp_http.rs diff --git a/src/clients.rs b/src/clients.rs index 4aa355b0..b7964934 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -58,6 +58,9 @@ pub use nlp::NlpClient; pub mod generation; pub use generation::GenerationClient; +pub mod nlp_http; +pub use nlp_http::NlpClientHttp; + pub mod openai; const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60); diff --git a/src/clients/generation.rs b/src/clients/generation.rs index afd54e59..c5068dc1 100644 --- a/src/clients/generation.rs +++ b/src/clients/generation.rs @@ -20,7 +20,7 @@ use futures::{StreamExt, TryStreamExt}; use hyper::HeaderMap; use tracing::{debug, instrument}; -use super::{BoxStream, Client, Error, NlpClient, TgisClient}; +use super::{BoxStream, Client, Error, NlpClient, TgisClient, NlpClientHttp}; use crate::{ health::HealthCheckResult, models::{ @@ -47,6 +47,7 @@ pub struct GenerationClient(Option); enum GenerationClientInner { Tgis(TgisClient), Nlp(NlpClient), + NlpHttp(NlpHttpClient), } #[cfg_attr(test, faux::methods)] @@ -59,6 +60,10 @@ impl GenerationClient { Self(Some(GenerationClientInner::Nlp(client))) } + pub fn nlp_http(client: NlpClientHttp) -> Self { + Self(Some(GenerationClientInner::NlpHttp(client))) + } + pub fn not_configured() -> Self { Self(None) } @@ -99,6 +104,20 @@ impl GenerationClient { .collect::>(); Ok((response.token_count as u32, tokens)) } + Some(GenerationClientInner::NlpHttp(client)) => { + let request = TokenizationTaskRequest { text }; + debug!(provider = "nlp-http", ?request, "sending tokenize request"); + let response = client + .tokenization_task_predict(&model_id, request, headers) + .await?; + debug!(provider = "nlp-http", ?response, "received tokenize response"); + let tokens = response + .results + .into_iter() + .map(|token| token.text) + .collect::>(); + Ok((response.token_count as u32, tokens)) + } None => Err(Error::ModelNotFound { model_id }), } } @@ -164,6 +183,46 @@ impl GenerationClient { debug!(provider = "nlp", ?response, "received generate response"); Ok(response.into()) } + Some(GenerationClientInner::NlpHttp(client)) => { + let request = if let Some(params) = params { + TextGenerationTaskRequest { + text, + max_new_tokens: params.max_new_tokens.map(|v| v as i64), + min_new_tokens: params.min_new_tokens.map(|v| v as i64), + truncate_input_tokens: params.truncate_input_tokens.map(|v| v as i64), + decoding_method: params.decoding_method, + top_k: params.top_k.map(|v| v as i64), + top_p: params.top_p, + typical_p: params.typical_p, + temperature: params.temperature, + repetition_penalty: params.repetition_penalty, + max_time: params.max_time, + exponential_decay_length_penalty: params + .exponential_decay_length_penalty + .map(Into::into), + stop_sequences: params.stop_sequences.unwrap_or_default(), + seed: params.seed.map(|v| v as u64), + preserve_input_text: params.preserve_input_text, + input_tokens: params.input_tokens, + generated_tokens: params.generated_tokens, + token_logprobs: params.token_logprobs, + token_ranks: params.token_ranks, + include_stop_sequence: params.include_stop_sequence, + } + } else { + TextGenerationTaskRequest { + text, + ..Default::default() + } + } + debug!(provider = "nlp-http", ?request, "sending generate request"); + let response = client + .text_generation_task_predict(&model_id, request, headers) + .await?; + debug!(provider = "nlp-http", ?response, "received generate response"); + Ok(response.into()) + } + }; None => Err(Error::ModelNotFound { model_id }), } } @@ -241,10 +300,53 @@ impl GenerationClient { .boxed(); Ok(response_stream) } + Some(GenerationClientInner::NlpHttp(client)) => { +` let request = if let Some(params) = params { + ServerStreamingTextGenerationTaskRequest{ + text, + max_new_tokens: params.max_new_tokens.map(|v| v as i64), + min_new_tokens: params.min_new_tokens.map(|v| v as i64), + truncate_input_tokens: params.truncate_input_tokens.map(|v| v as i64), + decoding_method: params.decoding_method, + top_k: params.top_k.map(|v| v as i64), + top_p: params.top_p, + typical_p: params.typical_p, + temperature: params.temperature, + repetition_penalty: params.repetition_penalty, + max_time: params.max_time, + exponential_decay_length_penalty: params + .exponential_decay_length_penalty + .map(Into::into), + stop_sequences: params.stop_sequences.unwrap_or_default(), + seed: params.seed.map(|v| v as u64), + preserve_input_text: params.preserve_input_text, + input_tokens: params.input_tokens, + generated_tokens: params.generated_tokens, + token_logprobs: params.token_logprobs, + token_ranks: params.token_ranks, + include_stop_sequence: params.include_stop_sequence, + } + } else { + ServerStreamingTextGenerationTaskRequest { + text, + ..Default::default() + } + }; + debug!( + provider = "nlp-http", + ?request, + "sending generate_stream request" + ); + let response_stream = client + .server_streaming_text_generation_task_predict(&model_id, request, headers) + .await? + .map_ok(Into::into) + .boxed(); + Ok(response_stream) + } None => Err(Error::ModelNotFound { model_id }), } } -} #[cfg_attr(test, faux::methods)] #[async_trait] @@ -257,6 +359,7 @@ impl Client for GenerationClient { match &self.0 { Some(GenerationClientInner::Tgis(client)) => client.health().await, Some(GenerationClientInner::Nlp(client)) => client.health().await, + Some(GenerationClientInner::NlpHttp(client)) => client.health().await, None => unimplemented!(), } } diff --git a/src/clients/nlp_http.rs b/src/clients/nlp_http.rs new file mode 100644 index 00000000..7b4c3645 --- /dev/null +++ b/src/clients/nlp_http.rs @@ -0,0 +1,184 @@ +use async_trait::async_trait; +use axum::extract::Extension; +use tracing::{info, instrument}; +use hyper::{HeaderMap, StatusCode}; +use tracing::{info, instrument}; + +use super::{ + create_http_client, Client, Error, HttpClient +}; +use crate::{ + config::ServiceConfig, + health::HealthCheckResult, + pb::{ + caikit::runtime::nlp::{ + nlp_service_client::NlpServiceClient, ServerStreamingTextGenerationTaskRequest, + TextGenerationTaskRequest, TokenClassificationTaskRequest, TokenizationTaskRequest, + }, + caikit_data_model::nlp::{ + GeneratedTextResult, GeneratedTextStreamResult, TokenClassificationResults, + TokenizationResults, + }, + }, + tracing_utils::trace_context_from_http_response +}; + +const DEFAULT_PORT: u16 = 8085; +const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; + +#[cfg_attr(test, faux::create)] +#[derive(Clone)] +pub struct NlpClientHttp { + client: HttpClient, + health_client: Option, +} + +#[cfg_attr(test, faux::methods)] +impl NlpClientHttp { + pub async fn new(config: &ServiceConfig) -> Self { + let client = create_http_client(DEFAULT_PORT, config); + let health_client = if let Some(health_config) = health_config { + Some(create_http_client(DEFAULT_PORT, health_config).await); + } else { + None + }; + Self { + client, + health_client, + } + } + + #[instrument(skip_all, fields(request.model))] + pub async tokenization_task_predict( + &self, + request: caikit::runtime::nlp::TokenizationTaskRequest, + headers: HeaderMap, + ) -> Result { + let url = self.client.base_url().join("/api/v1/task/tokenization").unwrap(); + let headers = with_traceparent_header(headers); + let request - request_with_headers(request, headers); + info!(?request, "sending request to NLP http service"); + let response = self + .client + .post(url) + .headers(headers) + .json(&request) + .send() + .await?; + match response.status() { + StatusCode::OK => OK(response.json().await?.into()), + _ => { + let code = response.status(); + let message = if let Ok(response) = response.json().await() { + response.message + } else { + "unknown error occured".into() + }; + Err(Error::Http {code, error}) + } + } + } + + pub async token_classification_task_predict( + &self, + request: TokenClassificationTaskRequest. + headers: HeaderMap, + ) -> Result { + let url = self.client.base_url().join("/api/v1/task/token-classification").unwrap(); + let headers = with_traceparent_header(headers); + info!(?request, "sending request to NLP http service"); + let response = self + .client + .post(url) + .headers(headers) + .json(&request) + .send() + .await?; + match response.status() { + StatusCode::OK => Ok(response.json::().await?.into()), + _ => { + let code = response.status(); + let message = if let Ok(response) = response.json::().await { + response.message + } else { + "unknown error occured".into() + }; + Err(Error::Http { code, message }) + } + } + + } + + pub async text_generation_task_predict( + &self, + request: TextGenerationTaskRequest, + headers: HeaderMap, + ) -> Result { + let url = self.client.base_url().join("/api/v1/task/text-generation").unwrap(); + let headers = with_traceparent_header(headers); + info!(?request, "sending request to NLP http service"); + let response = self + .client + .post(url) + .headers(headers) + .json(&request) + .send() + .await?; + match response.status() { + StatusCode::OK => Ok(response.json::().await?.into()), + _ => { + let code = response.status(); + let message = if let Ok(response) = response.json::().await { + response.message + } else { + "unknown error occured".into() + }; + Err(Error::Http { code, message }) + } + } + } + + pub async server_streaming_text_generation_task_predict( + &self, + request, ServerStreamingTextGenerationTaskRequest, + headers: HeaderMap, + ) -> Result { + let url = self.client.base_url().join("/api/v1/task/streaming-text-generation").unwrap(); + let headers = with_traceparent_header(headers); + info!(?request, "sending request to NLP http service"); + let response = self + .client + .post(url) + .headers(headers) + .json(&request) + .send() + .await?; + match response.status() { + StatusCode::OK => Ok(response.json::().await?.into()), + _ => { + let code = response.status(); + let message = if let Ok(response) = response.json::().await { + response.message + } else { + "unknown error occured".into() + }; + Err(Error::Http { code, message }) + } + } + } +} + +#[cfg_attr(test, faux::create)] +#[async_trait] +impl Client for NlpClientHttp { + fn name(&self) -> &str { + "nlp_http" + } + async fn health(&self) -> HealthCheckResult { + if let Some(health_client) = &self.health_client { + health_client.health().await + } else { + self.client.health().await + } + } +} \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 14053db1..25e552a8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -94,6 +94,8 @@ pub enum GenerationProvider { Tgis, #[serde(rename = "nlp")] Nlp, + #[serde(rename = "nlp-http")] + NlpHttp } /// Generation service configuration diff --git a/src/orchestrator.rs b/src/orchestrator.rs index a7fa465d..a6df6ce0 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -37,7 +37,7 @@ use crate::{ TextGenerationDetectorClient, }, openai::{ChatCompletionsRequest, OpenAiClient}, - ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, + ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, NlpHttpClient }, config::{DetectorType, GenerationProvider, OrchestratorConfig}, health::HealthCheckCache, @@ -181,6 +181,11 @@ async fn create_clients(config: &OrchestratorConfig) -> ClientMap { let generation_client = GenerationClient::nlp(nlp_client); clients.insert("generation".to_string(), generation_client); } + GenerationProvider::NlpClientHttp => { + let nlp_client_http = NlpClientHttp::new(&generation.service).await; + let generation_client = GenerationClient::nlp_http(nlp_client_http); + clients.insert("generation".to_string(), generation_client); + } } }