diff --git a/Cargo.lock b/Cargo.lock
index 72ede252..4bbb6417 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -555,6 +555,17 @@ dependencies = [
"windows-sys 0.52.0",
]
+[[package]]
+name = "eventsource-stream"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
+dependencies = [
+ "futures-core",
+ "nom",
+ "pin-project-lite",
+]
+
[[package]]
name = "fastrand"
version = "2.1.1"
@@ -604,11 +615,17 @@ dependencies = [
"faux",
"futures",
"ginepro",
+ "http-serde",
"hyper",
"hyper-util",
"mio",
+ "opentelemetry",
+ "opentelemetry-http",
+ "opentelemetry-otlp",
+ "opentelemetry_sdk",
"prost",
"reqwest",
+ "reqwest-eventsource",
"rustls",
"rustls-pemfile",
"rustls-webpki",
@@ -621,8 +638,10 @@ dependencies = [
"tokio-stream",
"tonic",
"tonic-build",
+ "tower-http",
"tower-service",
"tracing",
+ "tracing-opentelemetry",
"tracing-subscriber",
"tracing-test",
"url",
@@ -736,6 +755,12 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
[[package]]
name = "futures-util"
version = "0.3.30"
@@ -902,6 +927,16 @@ dependencies = [
"pin-project-lite",
]
+[[package]]
+name = "http-serde"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd"
+dependencies = [
+ "http 1.1.0",
+ "serde",
+]
+
[[package]]
name = "httparse"
version = "1.9.5"
@@ -1372,6 +1407,85 @@ dependencies = [
"vcpkg",
]
+[[package]]
+name = "opentelemetry"
+version = "0.24.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+ "js-sys",
+ "once_cell",
+ "pin-project-lite",
+ "thiserror",
+]
+
+[[package]]
+name = "opentelemetry-http"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad31e9de44ee3538fb9d64fe3376c1362f406162434609e79aea2a41a0af78ab"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "http 1.1.0",
+ "opentelemetry",
+ "reqwest",
+]
+
+[[package]]
+name = "opentelemetry-otlp"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b925a602ffb916fb7421276b86756027b37ee708f9dce2dbdcc51739f07e727"
+dependencies = [
+ "async-trait",
+ "futures-core",
+ "http 1.1.0",
+ "opentelemetry",
+ "opentelemetry-http",
+ "opentelemetry-proto",
+ "opentelemetry_sdk",
+ "prost",
+ "thiserror",
+ "tokio",
+ "tonic",
+]
+
+[[package]]
+name = "opentelemetry-proto"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9"
+dependencies = [
+ "opentelemetry",
+ "opentelemetry_sdk",
+ "prost",
+ "tonic",
+]
+
+[[package]]
+name = "opentelemetry_sdk"
+version = "0.24.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
+dependencies = [
+ "async-trait",
+ "futures-channel",
+ "futures-executor",
+ "futures-util",
+ "glob",
+ "once_cell",
+ "opentelemetry",
+ "percent-encoding",
+ "rand",
+ "serde_json",
+ "thiserror",
+ "tokio",
+ "tokio-stream",
+]
+
[[package]]
name = "overload"
version = "0.1.1"
@@ -1750,15 +1864,33 @@ dependencies = [
"tokio",
"tokio-native-tls",
"tokio-rustls",
+ "tokio-util",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
+ "wasm-streams",
"web-sys",
"webpki-roots",
"windows-registry",
]
+[[package]]
+name = "reqwest-eventsource"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
+dependencies = [
+ "eventsource-stream",
+ "futures-core",
+ "futures-timer",
+ "mime",
+ "nom",
+ "pin-project-lite",
+ "reqwest",
+ "thiserror",
+]
+
[[package]]
name = "reserve-port"
version = "2.0.1"
@@ -2380,6 +2512,23 @@ dependencies = [
"tracing",
]
+[[package]]
+name = "tower-http"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
+dependencies = [
+ "bitflags",
+ "bytes",
+ "http 1.1.0",
+ "http-body",
+ "http-body-util",
+ "pin-project-lite",
+ "tower-layer",
+ "tower-service",
+ "tracing",
+]
+
[[package]]
name = "tower-layer"
version = "0.3.3"
@@ -2436,6 +2585,24 @@ dependencies = [
"tracing-core",
]
+[[package]]
+name = "tracing-opentelemetry"
+version = "0.25.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
+dependencies = [
+ "js-sys",
+ "once_cell",
+ "opentelemetry",
+ "opentelemetry_sdk",
+ "smallvec",
+ "tracing",
+ "tracing-core",
+ "tracing-log",
+ "tracing-subscriber",
+ "web-time",
+]
+
[[package]]
name = "tracing-serde"
version = "0.1.3"
@@ -2703,6 +2870,19 @@ version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484"
+[[package]]
+name = "wasm-streams"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd"
+dependencies = [
+ "futures-util",
+ "js-sys",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
+
[[package]]
name = "web-sys"
version = "0.3.70"
@@ -2713,6 +2893,16 @@ dependencies = [
"wasm-bindgen",
]
+[[package]]
+name = "web-time"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
+dependencies = [
+ "js-sys",
+ "wasm-bindgen",
+]
+
[[package]]
name = "webpki-roots"
version = "0.26.6"
diff --git a/Cargo.toml b/Cargo.toml
index b4ae3e1e..310c2c62 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,16 +15,24 @@ path = "src/main.rs"
[dependencies]
anyhow = "1.0.86"
+async-trait = "0.1.81"
+async-stream = "0.3.5"
axum = { version = "0.7.5", features = ["json"] }
axum-extra = "0.9.3"
clap = { version = "4.5.15", features = ["derive", "env"] }
futures = "0.3.30"
ginepro = "0.8.1"
+http-serde = "2.1.1"
hyper = { version = "1.4.1", features = ["http1", "http2", "server"] }
hyper-util = { version = "0.1.7", features = ["server-auto", "server-graceful", "tokio"] }
mio = "1.0.2"
+opentelemetry = { version = "0.24.0", features = ["trace", "metrics"] }
+opentelemetry-http = { version = "0.13.0", features = ["reqwest"] }
+opentelemetry-otlp = { version = "0.17.0", features = ["http-proto"] }
+opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio", "metrics"] }
prost = "0.13.1"
reqwest = { version = "0.12.5", features = ["blocking", "rustls-tls", "json"] }
+reqwest-eventsource = "0.6.0"
rustls = {version = "0.23.12", default-features = false, features = ["std"]}
rustls-pemfile = "2.1.3"
rustls-webpki = "0.102.6"
@@ -36,13 +44,13 @@ tokio = { version = "1.39.2", features = ["rt", "rt-multi-thread", "parking_lot"
tokio-rustls = { version = "0.26.0" }
tokio-stream = { version = "0.1.15", features = ["sync"] }
tonic = { version = "0.12.1", features = ["tls", "tls-roots", "tls-webpki-roots"] }
+tower-http = { version = "0.5.2", features = ["trace"] }
tower-service = "0.3"
tracing = "0.1.40"
+tracing-opentelemetry = "0.25.0"
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
url = "2.5.2"
uuid = { version = "1.10.0", features = ["v4", "fast-rng"] }
-async-trait = "0.1.81"
-async-stream = "0.3.5"
[build-dependencies]
tonic-build = "0.12.1"
diff --git a/Dockerfile b/Dockerfile
index 6caee35c..d1913a4c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,7 @@
ARG UBI_MINIMAL_BASE_IMAGE=registry.access.redhat.com/ubi9/ubi-minimal
ARG UBI_BASE_IMAGE_TAG=latest
ARG PROTOC_VERSION=26.0
+ARG CONFIG_FILE=config/config.yaml
## Rust builder ################################################################
# Specific debian version so that compatible glibc version is used
@@ -23,10 +24,10 @@ RUN rustup component add rustfmt
## Orchestrator builder #########################################################
FROM rust-builder as fms-guardrails-orchestr8-builder
-COPY build.rs *.toml LICENSE /app
-COPY config/ /app/config
-COPY protos/ /app/protos
-COPY src/ /app/src
+COPY build.rs *.toml LICENSE /app/
+COPY ${CONFIG_FILE} /app/config/config.yaml
+COPY protos/ /app/protos/
+COPY src/ /app/src/
WORKDIR /app
@@ -50,7 +51,7 @@ RUN cargo fmt --check
FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} as fms-guardrails-orchestr8-release
COPY --from=fms-guardrails-orchestr8-builder /app/bin/ /app/bin/
-COPY config /app/config
+COPY ${CONFIG_FILE} /app/config/config.yaml
RUN microdnf install -y --disableplugin=subscription-manager shadow-utils compat-openssl11 && \
microdnf clean all --disableplugin=subscription-manager
diff --git a/config/config.yaml b/config/config.yaml
index 2fe6a657..ef5be2c4 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -11,6 +11,12 @@ generation:
service:
hostname: localhost
port: 8033
+# Generation server used for chat endpoints
+# chat_generation:
+# service:
+# hostname: localhost
+# port: 8080
+# # health_service:
# Any chunker servers that will be used by any detectors
chunkers:
# Chunker ID/name
@@ -26,11 +32,16 @@ chunkers:
detectors:
# Detector ID/name to be used in user requests
hap-en:
+ # Detector type (text_contents, text_generation, text_chat, text_context_doc)
+ type: text_contents
service:
- hostname: https://localhost/api/v1/text/contents # Full url / endpoint currently expected
+ hostname: localhost
port: 8080
# TLS ID/name, optional (detailed in `tls` section)
tls: detector
+ health_service:
+ hostname: localhost
+ port: 8081
# Chunker ID/name from `chunkers` section if applicable
chunker_id: en_regex
# Default score threshold for a detector. If a user
@@ -53,8 +64,7 @@ tls:
detector_bundle_no_ca:
cert_path: /path/to/client-bundle.pem
insecure: true
-
# Following section can be used to configure the allowed headers that orchestrator will pass to
# NLP provider and detectors. Note that, this section takes header keys, not values.
# passthrough_headers:
-# - header-key
\ No newline at end of file
+# - header-key
diff --git a/config/test.config.yaml b/config/test.config.yaml
index f32a0a26..cdde1905 100644
--- a/config/test.config.yaml
+++ b/config/test.config.yaml
@@ -3,6 +3,10 @@ generation:
service:
hostname: localhost
port: 443
+# chat_generation:
+# service:
+# hostname: localhost
+# port: 8080
chunkers:
test_chunker:
type: sentence
@@ -11,8 +15,9 @@ chunkers:
port: 8085
detectors:
test_detector:
+ type: text_contents
service:
- hostname: https://localhost/api/v1/text/contents
+ hostname: localhost
port: 8000
chunker_id: test_chunker
- default_threshold: 0.5
\ No newline at end of file
+ default_threshold: 0.5
diff --git a/docs/api/openapi_detector_api.yaml b/docs/api/openapi_detector_api.yaml
index 9bd96f0d..fb64ea23 100644
--- a/docs/api/openapi_detector_api.yaml
+++ b/docs/api/openapi_detector_api.yaml
@@ -5,9 +5,14 @@ info:
name: Apache 2.0
url: https://www.apache.org/licenses/LICENSE-2.0.html
version: 0.0.1
+tags:
+ - name: Text
+ description: Detections on text
paths:
/api/v1/text/contents:
post:
+ tags:
+ - Text
summary: Text Content Analysis Unary Handler
description: >-
Detectors that work on content text, be it user prompt or generated
@@ -67,6 +72,8 @@ paths:
$ref: '#/components/schemas/Error'
/api/v1/text/generation:
post:
+ tags:
+ - Text
summary: Generation Analysis Unary Handler
description: >-
Detectors that run on prompt and text generation output.
@@ -115,6 +122,8 @@ paths:
$ref: '#/components/schemas/Error'
/api/v1/text/chat:
post:
+ tags:
+ - Text
summary: Chat Analysis Unary Handler
description: >-
Detectors that analyze chat messages and provide detections
@@ -162,6 +171,8 @@ paths:
$ref: '#/components/schemas/Error'
/api/v1/text/context/doc:
post:
+ tags:
+ - Text
summary: Context Analysis Unary Handler
description: >-
Detectors that work on a context created by document(s).
diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml
index d7196d3e..cd27db1d 100644
--- a/docs/api/orchestrator_openapi_0_1_0.yaml
+++ b/docs/api/orchestrator_openapi_0_1_0.yaml
@@ -3,12 +3,12 @@ info:
title: FMS Orchestrator API
version: 0.1.0
tags:
- - name: Task - Text Generation, with detection
- description: Detections on text generation model input and/or output
- - name: Task - Detection
- description: Standalone detections
- - name: Task - Chat Completions, with detection
- description: Detections on list of messages comprising a conversation and/or completions from a model
+ - name: Task - Text Generation, with detection
+ description: Detections on text generation model input and/or output
+ - name: Task - Detection
+ description: Standalone detections
+ - name: Task - Chat Completions, with detection
+ description: Detections on list of messages comprising a conversation and/or completions from a model
paths:
/health:
get:
@@ -17,7 +17,7 @@ paths:
summary: Performs quick liveliness check of the orchestrator service
operationId: health
responses:
- '200':
+ "200":
description: Healthy
content:
application/json:
@@ -42,18 +42,18 @@ paths:
type: boolean
default: false
responses:
- '200':
+ "200":
description: Orchestrator successfully probed client health statuses
content:
application/json:
schema:
- $ref: '#/components/schemas/HealthProbeResponse'
- '503':
+ $ref: "#/components/schemas/InfoResponse"
+ "503":
description: Orchestrator failed to probe client health statuses
content:
application/json:
schema:
- $ref: '#/components/schemas/HealthProbeResponse'
+ $ref: "#/components/schemas/InfoResponse"
/api/v1/task/classification-with-text-generation:
post:
tags:
@@ -65,27 +65,27 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/GuardrailsHttpRequest'
+ $ref: "#/components/schemas/GuardrailsHttpRequest"
required: true
responses:
- '200':
+ "200":
description: Successful Response
content:
application/json:
schema:
- $ref: '#/components/schemas/ClassifiedGeneratedTextResult'
- '404':
+ $ref: "#/components/schemas/ClassifiedGeneratedTextResult"
+ "404":
description: Resource Not Found
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
- '422':
+ $ref: "#/components/schemas/Error"
+ "422":
description: Validation Error
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
+ $ref: "#/components/schemas/Error"
/api/v1/task/server-streaming-classification-with-text-generation:
post:
tags:
@@ -97,27 +97,27 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/GuardrailsHttpRequest'
+ $ref: "#/components/schemas/GuardrailsHttpRequest"
required: true
responses:
- '200':
+ "200":
description: Successful Response
content:
- application/json:
+ text/event-stream:
schema:
- $ref: '#/components/schemas/ClassifiedGeneratedTextStreamResult'
- '404':
+ $ref: "#/components/schemas/ClassifiedGeneratedTextStreamResult"
+ "404":
description: Resource Not Found
content:
- application/json:
+ text/event-stream:
schema:
- $ref: '#/components/schemas/Error'
- '422':
+ $ref: "#/components/schemas/Error"
+ "422":
description: Validation Error
content:
- application/json:
+ text/event-stream:
schema:
- $ref: '#/components/schemas/Error'
+ $ref: "#/components/schemas/Error"
/api/v2/text/generation-detection:
post:
tags:
@@ -129,27 +129,27 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/GenerationDetectionRequest'
+ $ref: "#/components/schemas/GenerationDetectionRequest"
required: true
responses:
- '200':
+ "200":
description: Successful Response
content:
application/json:
schema:
- $ref: '#/components/schemas/GenerationDetectionResponse'
- '404':
+ $ref: "#/components/schemas/GenerationDetectionResponse"
+ "404":
description: Resource Not Found
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
- '422':
+ $ref: "#/components/schemas/Error"
+ "422":
description: Validation Error
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
+ $ref: "#/components/schemas/Error"
/api/v2/text/detection/content:
post:
@@ -162,27 +162,77 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/DetectionContentRequest'
+ $ref: "#/components/schemas/DetectionContentRequest"
required: true
responses:
- '200':
+ "200":
description: Successful Response
content:
application/json:
schema:
- $ref: '#/components/schemas/DetectionContentResponse'
- '404':
+ $ref: "#/components/schemas/DetectionContentResponse"
+ "404":
description: Resource Not Found
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
- '422':
+ $ref: "#/components/schemas/Error"
+ "422":
description: Validation Error
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
+ $ref: "#/components/schemas/Error"
+ /api/v2/text/detection/stream-content:
+ post:
+ tags:
+ - Task - Detection
+ summary: Detection task on input content stream
+ operationId: >-
+ api_v2_detection_text_content_bidi_stream_handler
+ requestBody:
+ content:
+ application/x-ndjson:
+ schema:
+ oneOf:
+ - $ref: "#/components/schemas/DetectionContentRequest"
+ - $ref: "#/components/schemas/DetectionContentStreamEvent"
+ # In OpenAPI 3.0, examples cannot be present in schemas,
+            # whereas object-level examples are supported in 3.1
+ examples:
+ first_event:
+ summary: First text event with detectors
+ value:
+ detectors:
+ hap-v1-model-en: {}
+ content: "my text here"
+ text:
+ summary: Regular text event
+ value:
+ content: "my text here"
+ required: true
+ responses:
+ "200":
+ description: Successful Response
+ content:
+ text/event-stream:
+ schema:
+ # NOTE: This endpoint, like the
+ # `server-streaming-classification-with-text-generation`
+              # endpoint, will produce streamed events
+ $ref: "#/components/schemas/DetectionContentStreamResponse"
+ "404":
+ description: Resource Not Found
+ content:
+ text/event-stream:
+ schema:
+ $ref: "#/components/schemas/Error"
+ "422":
+ description: Validation Error
+ content:
+ text/event-stream:
+ schema:
+ $ref: "#/components/schemas/Error"
/api/v2/text/detection/chat:
post:
tags:
@@ -194,27 +244,27 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/DetectionChatRequest'
+ $ref: "#/components/schemas/DetectionChatRequest"
required: true
responses:
- '200':
+ "200":
description: Successful Response
content:
application/json:
schema:
- $ref: '#/components/schemas/DetectionChatResponse'
- '404':
+ $ref: "#/components/schemas/DetectionChatResponse"
+ "404":
description: Resource Not Found
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
- '422':
+ $ref: "#/components/schemas/Error"
+ "422":
description: Validation Error
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
+ $ref: "#/components/schemas/Error"
/api/v2/text/detection/context:
post:
@@ -227,27 +277,27 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/DetectionContextDocsRequest'
+ $ref: "#/components/schemas/DetectionContextDocsRequest"
required: true
responses:
- '200':
+ "200":
description: Successful Response
content:
application/json:
schema:
- $ref: '#/components/schemas/DetectionContextDocsResponse'
- '404':
+ $ref: "#/components/schemas/DetectionContextDocsResponse"
+ "404":
description: Resource Not Found
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
- '422':
+ $ref: "#/components/schemas/Error"
+ "422":
description: Validation Error
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
+ $ref: "#/components/schemas/Error"
/api/v2/text/detection/generated:
post:
tags:
@@ -259,126 +309,109 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/GeneratedTextDetectionRequest'
+ $ref: "#/components/schemas/GeneratedTextDetectionRequest"
required: true
responses:
- '200':
+ "200":
description: Successful Response
content:
application/json:
schema:
- $ref: '#/components/schemas/GeneratedTextDetectionResponse'
- '404':
+ $ref: "#/components/schemas/GeneratedTextDetectionResponse"
+ "404":
description: Resource Not Found
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
- '422':
+ $ref: "#/components/schemas/Error"
+ "422":
description: Validation Error
content:
application/json:
schema:
- $ref: '#/components/schemas/Error'
-
+ $ref: "#/components/schemas/Error"
+
/api/v2/chat/completions-detection:
post:
- tags:
- - Task - Chat Completions, with detection
- operationId: >-
- api_v2_chat_completions_detection_handler
- summary: Creates a model response with detections for the given chat conversation.
- requestBody:
- required: true
- content:
- application/json:
- schema:
- $ref: "#/components/schemas/GuardrailsCreateChatCompletionRequest"
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema:
- $ref: "#/components/schemas/GuardrailsCreateChatCompletionResponse"
- '404':
- description: Resource Not Found
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/Error'
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/Error'
+ tags:
+ - Task - Chat Completions, with detection
+ operationId: >-
+ api_v2_chat_completions_detection_handler
+ summary: Creates a model response with detections for the given chat conversation.
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/GuardrailsCreateChatCompletionRequest"
+ responses:
+ "200":
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/GuardrailsCreateChatCompletionResponse"
+ "404":
+ description: Resource Not Found
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Error"
+ "422":
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Error"
components:
schemas:
HealthStatus:
- type: string
- enum:
- - HEALTHY
- - UNHEALTHY
- - UNKNOWN
- title: Health Status
+ type: string
+ enum:
+ - HEALTHY
+ - UNHEALTHY
+ - UNKNOWN
+ title: Health Status
HealthCheckResult:
- oneOf:
- - properties:
- health_status:
- $ref: '#/components/schemas/HealthStatus'
- response_code:
- type: string
- title: Response Code
- example: "HTTP 200 OK"
- reason:
- type: string
- title: Reason
- example: "Service not found"
- required:
- - health_status
- - response_code
- - additionalProperties:
- $ref: '#/components/schemas/HealthStatus'
+ properties:
+ status:
+ $ref: "#/components/schemas/HealthStatus"
+ code:
+ type: string
+ title: Response Code
+ example: 200
+ reason:
+ type: string
+ title: Reason
+ example: "Not Found"
+ required:
+ - status
type: object
- title: Health Check Response
- HealthProbeResponse:
+ title: Health Check Result
+ InfoResponse:
properties:
services:
type: object
title: Health status for each client service
- properties:
- generation:
- type: object
- title: Generation Services
- items:
- $ref: '#/components/schemas/HealthCheckResult'
- detectors:
- type: object
- title: Detector Services
- items:
- $ref: '#/components/schemas/HealthCheckResult'
- chunkers:
- type: object
- title: Chunker Services
- items:
- $ref: '#/components/schemas/HealthCheckResult'
+ items:
+ $ref: "#/components/schemas/HealthCheckResult"
required:
- services
type: object
- title: Health Probe Response
+ title: Info Response
DetectionContentRequest:
properties:
detectors:
type: object
title: Detectors
default: {}
- example:
+ example:
hap-v1-model-en: {}
content:
type: string
title: Content
+ example: "my text here"
required: ["detectors", "content"]
additionalProperties: false
type: object
@@ -388,7 +421,7 @@ components:
detections:
type: array
items:
- $ref: '#/components/schemas/DetectionContentResponseObject'
+ $ref: "#/components/schemas/DetectionContentResponseObject"
additionalProperties: false
required: ["detections"]
type: object
@@ -415,15 +448,38 @@ components:
title: Score
title: Content Detection Response Object
example:
- - {
- "start": 0,
- "end": 20,
- "text": "string",
- "detection_type": "HAP",
- "detection": "has_HAP",
- "detector_id": "hap-v1-model-en",
- "score": 0.999
- }
+ start: 0
+ end: 20
+ text: "string"
+ detection_type: "HAP"
+ detection: "has_HAP"
+ detector_id: "hap-v1-model-en"
+ score: 0.999
+ DetectionContentStreamEvent:
+ properties:
+ content:
+ type: string
+ title: Content
+ example: "my text here"
+ required: ["content"]
+ type: object
+ description: Individual stream event
+ title: Content Detection Stream Event
+ DetectionContentStreamResponse:
+ properties:
+ detections:
+ type: array
+ items:
+ $ref: "#/components/schemas/DetectionContentResponseObject"
+ processed_index:
+ anyOf:
+ - type: integer
+ title: Processed Index
+ start_index:
+ type: integer
+ title: Start Index
+ type: object
+ title: Content Detection Stream Response
DetectionChatRequest:
properties:
@@ -431,7 +487,7 @@ components:
type: object
title: Detectors
default: {}
- example:
+ example:
chat-v1-model-en: {}
messages:
title: Chat Messages
@@ -465,14 +521,14 @@ components:
title: Detections on entire history of chat messages
title: Chat Detection Response
required: ["detections"]
-
+
DetectionContextDocsRequest:
properties:
detectors:
type: object
title: Detectors
default: {}
- example:
+ example:
context-v1-model-en: {}
content:
type: string
@@ -500,7 +556,7 @@ components:
detections:
type: array
items:
- $ref: '#/components/schemas/DetectionContextDocsResponseObject'
+ $ref: "#/components/schemas/DetectionContextDocsResponseObject"
required: ["detections"]
title: Context Docs Detection Response
DetectionContextDocsResponseObject:
@@ -516,9 +572,9 @@ components:
title: Score
evidence:
anyOf:
- - items:
- $ref: '#/components/schemas/EvidenceObj'
- type: array
+ - items:
+ $ref: "#/components/schemas/EvidenceObj"
+ type: array
title: Context Docs Detection Response Object
GenerationDetectionRequest:
@@ -533,11 +589,11 @@ components:
type: object
title: Detectors
default: {}
- example:
+ example:
generation-detection-v1-model-en: {}
text_gen_parameters:
allOf:
- - $ref: '#/components/schemas/GuardrailsTextGenerationParameters'
+ - $ref: "#/components/schemas/GuardrailsTextGenerationParameters"
type: object
required: ["model_id", "prompt", "detectors"]
title: Generation-Detection Request
@@ -566,7 +622,7 @@ components:
title: Input token Count
title: Generation Detection Response
required: ["generated_text", "detections"]
-
+
GeneratedTextDetectionRequest:
properties:
prompt:
@@ -579,7 +635,7 @@ components:
type: object
title: Detectors
default: {}
- example:
+ example:
generated-detection-v1-model-en: {}
type: object
required: ["generated_text", "prompt", "detectors"]
@@ -589,7 +645,7 @@ components:
detections:
type: array
items:
- $ref: '#/components/schemas/GeneratedTextDetectionResponseObject'
+ $ref: "#/components/schemas/GeneratedTextDetectionResponseObject"
required: ["detections"]
title: Generated Text Detection Response
GeneratedTextDetectionResponseObject:
@@ -614,10 +670,10 @@ components:
title: Generated Text
token_classification_results:
anyOf:
- - $ref: '#/components/schemas/TextGenTokenClassificationResults'
+ - $ref: "#/components/schemas/TextGenTokenClassificationResults"
finish_reason:
anyOf:
- - $ref: '#/components/schemas/FinishReason'
+ - $ref: "#/components/schemas/FinishReason"
generated_token_count:
anyOf:
- type: integer
@@ -633,19 +689,19 @@ components:
warnings:
anyOf:
- items:
- $ref: '#/components/schemas/InputWarning'
+ $ref: "#/components/schemas/InputWarning"
type: array
title: Warnings
tokens:
anyOf:
- items:
- $ref: '#/components/schemas/GeneratedToken'
+ $ref: "#/components/schemas/GeneratedToken"
type: array
title: Tokens
input_tokens:
anyOf:
- items:
- $ref: '#/components/schemas/GeneratedToken'
+ $ref: "#/components/schemas/GeneratedToken"
type: array
title: Input Tokens
additionalProperties: false
@@ -660,10 +716,10 @@ components:
title: Generated Text
token_classification_results:
anyOf:
- - $ref: '#/components/schemas/TextGenTokenClassificationResults'
+ - $ref: "#/components/schemas/TextGenTokenClassificationResults"
finish_reason:
anyOf:
- - $ref: '#/components/schemas/FinishReason'
+ - $ref: "#/components/schemas/FinishReason"
generated_token_count:
anyOf:
- type: integer
@@ -679,19 +735,19 @@ components:
warnings:
anyOf:
- items:
- $ref: '#/components/schemas/InputWarning'
+ $ref: "#/components/schemas/InputWarning"
type: array
title: Warnings
tokens:
anyOf:
- items:
- $ref: '#/components/schemas/GeneratedToken'
+ $ref: "#/components/schemas/GeneratedToken"
type: array
title: Tokens
input_tokens:
anyOf:
- items:
- $ref: '#/components/schemas/GeneratedToken'
+ $ref: "#/components/schemas/GeneratedToken"
type: array
title: Input Tokens
processed_index:
@@ -706,18 +762,18 @@ components:
type: object
title: Classified Generated Text Stream Result
TextGenTokenClassificationResults:
- # By default open-api spec consider all fields as optional
+  # By default, the OpenAPI spec considers all fields optional
properties:
input:
anyOf:
- items:
- $ref: '#/components/schemas/TokenClassificationResult'
+ $ref: "#/components/schemas/TokenClassificationResult"
type: array
title: Input
output:
anyOf:
- items:
- $ref: '#/components/schemas/TokenClassificationResult'
+ $ref: "#/components/schemas/TokenClassificationResult"
type: array
title: Output
additionalProperties: false
@@ -765,7 +821,7 @@ components:
default: {}
required:
- detectors
-
+
GuardrailsCreateChatCompletionResponse:
title: Guardrails Chat Completion Response
description: Guardrails chat completion response (adds detections on OpenAI chat completion)
@@ -778,7 +834,7 @@ components:
warnings:
type: array
items:
- $ref: '#/components/schemas/Warning'
+ $ref: "#/components/schemas/Warning"
required:
- detections
@@ -800,27 +856,27 @@ components:
output:
pii-v1: {}
conversation-detector: {}
-
+
ChatCompletionsDetections:
title: Chat Completions Detections
properties:
input:
type: array
items:
- $ref: '#/components/schemas/MessageDetections'
+ $ref: "#/components/schemas/MessageDetections"
title: Detections on input to chat completions
default: {}
output:
type: array
items:
- $ref: '#/components/schemas/ChoiceDetections'
+ $ref: "#/components/schemas/ChoiceDetections"
title: Detections on output of chat completions
default: {}
default: {}
example:
input:
- message_index: 0
- results:
+ results:
- {
"start": 0,
"end": 80,
@@ -828,7 +884,7 @@ components:
"detection_type": "HAP",
"detection": "has_HAP",
"detector_id": "hap-v1-model-en", # Future addition
- "score": 0.999
+ "score": 0.999,
}
output:
- choice_index: 0
@@ -841,15 +897,15 @@ components:
"detection_type": "HAP",
"detection": "has_HAP",
"detector_id": "hap-v1-model-en", # Future addition
- "score": 0.999
+ "score": 0.999,
}
- {
"detection_type": "string",
"detection": "string",
"detector_id": "relevance-v1-en", # Future addition
- "score": 0
+ "score": 0,
}
-
+
MessageDetections:
title: Message Detections
properties:
@@ -863,9 +919,9 @@ components:
type: array
items:
anyOf:
- - $ref: '#/components/schemas/DetectionContentResponseObject'
- - $ref: '#/components/schemas/DetectionContextDocsResponseObject'
- - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject'
+ - $ref: "#/components/schemas/DetectionContentResponseObject"
+ - $ref: "#/components/schemas/DetectionContextDocsResponseObject"
+ - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject"
required:
- message_index
ChoiceDetections:
@@ -881,9 +937,9 @@ components:
type: array
items:
anyOf:
- - $ref: '#/components/schemas/DetectionContentResponseObject'
- - $ref: '#/components/schemas/DetectionContextDocsResponseObject'
- - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject'
+ - $ref: "#/components/schemas/DetectionContentResponseObject"
+ - $ref: "#/components/schemas/DetectionContextDocsResponseObject"
+ - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject"
required:
- choice_index
@@ -948,7 +1004,7 @@ components:
evidence:
anyOf:
- items:
- $ref: '#/components/schemas/Evidence'
+ $ref: "#/components/schemas/Evidence"
type: array
type: object
required:
@@ -1010,7 +1066,7 @@ components:
title: Inputs
guardrail_config:
allOf:
- - $ref: '#/components/schemas/GuardrailsConfig'
+ - $ref: "#/components/schemas/GuardrailsConfig"
default:
input:
masks: []
@@ -1019,7 +1075,7 @@ components:
models: {}
text_gen_parameters:
allOf:
- - $ref: '#/components/schemas/GuardrailsTextGenerationParameters'
+ - $ref: "#/components/schemas/GuardrailsTextGenerationParameters"
type: object
required:
- model_id
@@ -1059,7 +1115,7 @@ components:
title: Max Time
exponential_decay_length_penalty:
allOf:
- - $ref: '#/components/schemas/ExponentialDecayLengthPenalty'
+ - $ref: "#/components/schemas/ExponentialDecayLengthPenalty"
stop_sequences:
items:
type: string
@@ -1094,7 +1150,7 @@ components:
properties:
id:
allOf:
- - $ref: '#/components/schemas/InputWarningReason'
+ - $ref: "#/components/schemas/InputWarningReason"
message:
type: string
title: Message
diff --git a/docs/architecture/adrs/005-chat-completion-support.md b/docs/architecture/adrs/005-chat-completion-support.md
index d04e92da..6fb5156a 100644
--- a/docs/architecture/adrs/005-chat-completion-support.md
+++ b/docs/architecture/adrs/005-chat-completion-support.md
@@ -73,4 +73,4 @@ This means that the orchestrator will have to be able to track chunking and dete
## Status
-Proposed
+Accepted
diff --git a/docs/architecture/adrs/006-detector-type.md b/docs/architecture/adrs/006-detector-type.md
new file mode 100644
index 00000000..99c4e077
--- /dev/null
+++ b/docs/architecture/adrs/006-detector-type.md
@@ -0,0 +1,45 @@
+# ADR 006: Detector Type
+
+This ADR documents the decision to add the `type` parameter for detectors in the orchestrator config.
+
+## Motivation
+
+The guardrails orchestrator interfaces with different types of detectors.
+Detectors of a given type are compatible with only a subset of orchestrator endpoints.
+In order to reduce the chance of misconfiguration, we need a way to map detectors to compatible endpoints only. This additionally gives us a way to refer to a particular detector type within the code without inspecting its `hostname` (URL), which can be error-prone, for example when validating whether a given detector works with a given orchestrator endpoint.
+
+## Decision
+
+We decided to add the `type` parameter to the detectors configuration.
+Possible values are `text_contents`, `text_chat`, `text_generation`, and `text_context_doc`.
+Below is an example detector configuration.
+
+```yaml
+detectors:
+ my_detector:
+    type: text_contents # Options: text_contents, text_chat, text_context_doc, text_generation
+ service:
+ hostname: my-detector.com
+ port: 8080
+ tls: my_certs
+ chunker_id: my_chunker
+ default_threshold: 0.5
+```
+
+## Consequences
+
+1. Reduced misconfiguration risk.
+2. Future logic can be implemented for detectors of a particular type.
+3. `hostname` no longer needs the full URL, but only the actual hostname.
+4. If `tls` is provided, the `https` protocol is used; otherwise, `http` is used.
+5. Not including `type` results in a configuration validation error on orchestrator startup.
+6. Detector endpoints are automatically configured based on `type` as follows (a minimal sketch of this mapping appears after this list):
+ * `text_contents` -> `/api/v1/text/contents`
+ * `text_chat` -> `/api/v1/text/chat`
+ * `text_context_doc` -> `/api/v1/text/context/doc`
+ * `text_generation` -> `/api/v1/text/generation`
+
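+Below is a minimal sketch of how this mapping could look in code; the type and method names are illustrative
+assumptions, not the orchestrator's actual implementation.
+
+```rust
+/// Hypothetical detector type enum mirroring the config values above.
+#[derive(Debug, Clone, Copy)]
+pub enum DetectorType {
+    TextContents,
+    TextChat,
+    TextContextDoc,
+    TextGeneration,
+}
+
+impl DetectorType {
+    /// Endpoint path derived from the configured `type`.
+    pub fn endpoint(&self) -> &'static str {
+        match self {
+            DetectorType::TextContents => "/api/v1/text/contents",
+            DetectorType::TextChat => "/api/v1/text/chat",
+            DetectorType::TextContextDoc => "/api/v1/text/context/doc",
+            DetectorType::TextGeneration => "/api/v1/text/generation",
+        }
+    }
+}
+```
+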
+## Status
+
+Accepted
\ No newline at end of file
diff --git a/docs/architecture/adrs/007-orchestrator-telemetry.md b/docs/architecture/adrs/007-orchestrator-telemetry.md
new file mode 100644
index 00000000..19faffac
--- /dev/null
+++ b/docs/architecture/adrs/007-orchestrator-telemetry.md
@@ -0,0 +1,96 @@
+# ADR 007: Orchestrator Telemetry
+
+The guardrails orchestrator uses [OpenTelemetry](https://opentelemetry.io/) to collect and export telemetry data (traces, metrics, and logs).
+
+## Motivation
+
+The orchestrator needs to collect telemetry data for monitoring and observability. It also needs to trace
+spans for incoming requests and across client requests to configured detectors, chunkers, and generation services,
+and to aggregate detailed traces, metrics, and logs that can be monitored from a variety of observability backends.
+
+## Decision
+
+### OpenTelemetry and `tracing`
+
+The orchestrator and client services will make use of the OpenTelemetry SDK and the [OpenTelemetry Protocol (OTLP)](https://opentelemetry.io/docs/specs/otel/protocol/)
+for consolidating and collecting telemetry data across services. The orchestrator will be responsible for collecting
+telemetry data throughout the lifetime of a request using the `tracing` crate, which is the de facto choice for logging
+and tracing for OpenTelemetry in Rust, and exporting it through the OTLP exporter if configured. The OTLP exporter will
+send telemetry data to a gRPC or HTTP endpoint that can be configured to point to a running OpenTelemetry (OTEL) collector.
+Similarly, detectors should also be able to collect and export telemetry through OTLP to the same OTEL collector.
+From the OTEL collector, the telemetry data can then be exported to multiple backends. The OTEL collector and
+any observability backends can all be configured alongside the orchestrator and detectors in a deployment. A minimal
+sketch of wiring this export pipeline is shown below.
+
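+A minimal sketch of wiring the OTLP trace exporter into `tracing`, assuming the `opentelemetry-otlp` 0.17 pipeline API
+with a gRPC exporter and the Tokio batch runtime; this is illustrative, not the orchestrator's actual initialization:
+
+```rust
+use opentelemetry::trace::TracerProvider as _;
+use opentelemetry_otlp::WithExportConfig;
+use tracing_subscriber::layer::SubscriberExt;
+
+fn init_tracing() -> Result<(), Box<dyn std::error::Error>> {
+    // Build a tracer provider that batches spans and exports them over OTLP/gRPC.
+    let provider = opentelemetry_otlp::new_pipeline()
+        .tracing()
+        .with_exporter(
+            opentelemetry_otlp::new_exporter()
+                .tonic()
+                .with_endpoint("http://localhost:4317"),
+        )
+        .install_batch(opentelemetry_sdk::runtime::Tokio)?;
+    let tracer = provider.tracer("fms_guardrails_orchestr8");
+    // Bridge `tracing` spans into OpenTelemetry and install globally.
+    let subscriber = tracing_subscriber::registry()
+        .with(tracing_opentelemetry::layer().with_tracer(tracer));
+    tracing::subscriber::set_global_default(subscriber)?;
+    Ok(())
+}
+```
+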
+### Instrumentation
+An incoming request to the orchestrator will initialize a new trace; a trace ID and a request should therefore be in
+one-to-one correspondence. All important functions throughout the control flow of handling a request in the orchestrator
+will be instrumented with the `#[tracing::instrument]` attribute macro above the function definition. This will create
+and enter a span for each function call and add it to the trace of the request. Here, important functions means any
+functions that perform significant business logic or may incur significant latency, including all the handler functions
+for incoming and outgoing requests. It is up to the developer's discretion which functions are "significant" enough to
+warrant a new span in the trace, but a new tracing span can always be added trivially with the instrument macro, as in
+the sketch below.
+
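+A sketch of instrumenting a function; the function and its logic are illustrative, not actual orchestrator code:
+
+```rust
+use tracing::instrument;
+
+/// Each call creates and enters a `detect` span on the current trace,
+/// recording `detector_id` as a span field; the `text` argument is skipped.
+#[instrument(skip(text))]
+async fn detect(detector_id: &str, text: &str) -> bool {
+    tracing::debug!("running detection");
+    !text.is_empty()
+}
+```
+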
+### Metrics
+The orchestrator will aggregate metrics regarding the requests it has received/handled, and annotate the metrics with
+span attributes allowing for detailed filtering and monitoring. The metrics will be exported through the OTLP exporter
+through the metrics provider. Traces exported through the traces provider can also have R.E.D. (rate, errors, and
+duration) metrics attached to them implicitly by the OTEL collector using the `spanmetrics` connector. Both the OTLP
+metrics and the `spanmetrics` metrics can be exported to configured metrics backends like Prometheus or Grafana.
+The orchestrator will handle a variety of useful metrics such as counters and histograms for received/handled
+successful/failed requests, request and stream durations, and server/client errors. Traces and metrics will also relate
+incoming orchestrator requests to respective client requests/responses, and collect more business specific metrics
+e.g. regarding the outcome of running detection or generation. A sketch of recording a simple counter is shown below.
+
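+A sketch of recording a counter through the global meter provider; the metric and attribute names are illustrative,
+and in practice instruments would be created once and reused rather than per call:
+
+```rust
+use opentelemetry::{global, KeyValue};
+
+fn record_request(endpoint: &'static str) {
+    let meter = global::meter("fms_guardrails_orchestr8");
+    // Counter of received requests, annotated with the endpoint attribute.
+    let counter = meter.u64_counter("requests_total").init();
+    counter.add(1, &[KeyValue::new("endpoint", endpoint)]);
+}
+```
+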
+### Configuration
+The orchestrator will expose CLI args/env variables for configuring the OTLP exporter:
+- `OTEL_EXPORTER_OTLP_PROTOCOL=grpc|http` to set the protocol for all the OTLP endpoints
+ - `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` and `OTEL_EXPORTER_OTLP_METRICS_PROTOCOL` to set/override the protocol for
+ traces or metrics.
+- `--otlp-endpoint, OTEL_EXPORTER_OTLP_ENDPOINT` to set the OTLP endpoint
+ - defaults: gRPC `localhost:4317` and HTTP `localhost:4318`
+ - `--otlp-traces-endpoint, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT` and `--otlp-metrics-endpoint,
+ OTEL_EXPORTER_OTLP_METRICS_ENDPOINT` to set/override the endpoint for traces or metrics
+  - default to `localhost:4317` for gRPC for all data types, and to `localhost:4318/v1/traces` or `localhost:4318/v1/metrics` for HTTP
+- `--otlp-export, OTLP_EXPORT=traces,metrics` to specify a list of which data types to export to the OTLP exporters, separated by a
+ comma. The possible values are traces, metrics, or both. The OTLP standard specifies three data types (`traces`,
+ `metrics`, `logs`) but since we use the recommended `tracing` crate for logging, we can export logs as traces and
+ not use the separate (more experimental) logging export pipeline.
+- `RUST_LOG=error|warn|info|debug|trace` to set the Rust log level.
+- `--log-format, LOG_FORMAT=full|compact|json|pretty` to set the logging format for logs printed to stdout. All logs collected as
+  traces by OTLP will just be structured traces; this argument applies specifically to stdout. Default is `full`.
+- `--quiet, -q` to silence logging to stdout. If `OTLP_EXPORT=traces` is provided, all logs can still be viewed
+ as traces from an observability backend.
+
+### Cross-service tracing
+The orchestrator and client services will be able to consolidate telemetry and share observability through a common
+configuration and backends. This will be made possible through the use of the OTLP standard as well as through the
+propagation of the trace context through requests across services using the standardized `traceparent` header. The
+orchestrator will be expected to initialize a new trace for an incoming request and pass `traceparent` headers
+corresponding to this trace to any requests outgoing to clients, and similarly, the orchestrator will expect the client
+to provide a `traceparent` header in the response. The orchestrator will not propagate the `traceparent` to outgoing
+responses back to the end user (or expect `traceparent` in incoming requests) for security reasons. A sketch of the
+header injection is shown below.
+
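+A sketch of the header injection, assuming the W3C `TraceContextPropagator` and the `opentelemetry-http` header
+injector; the name mirrors the orchestrator's `with_traceparent_header` helper but the body here is illustrative:
+
+```rust
+use axum::http::HeaderMap;
+use opentelemetry::propagation::TextMapPropagator;
+use opentelemetry_http::HeaderInjector;
+use opentelemetry_sdk::propagation::TraceContextPropagator;
+use tracing_opentelemetry::OpenTelemetrySpanExt;
+
+/// Returns the headers with a `traceparent` entry for the current span's trace context.
+fn with_traceparent_header(mut headers: HeaderMap) -> HeaderMap {
+    let ctx = tracing::Span::current().context();
+    TraceContextPropagator::new().inject_context(&ctx, &mut HeaderInjector(&mut headers));
+    headers
+}
+```
+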
+## Status
+
+Accepted
+
+## Consequences
+
+- The orchestrator and client services have a common standard to conform to for telemetry, allowing for traceability
+  across different services. No other attempt at telemetry standardization is as widely accepted as OpenTelemetry, or
+  has the same level of support from existing observability and monitoring services.
+- The deployment of the orchestrator must be configured with telemetry service(s) listening for telemetry exported on
+ the specified endpoint(s). An [OTEL collector](https://opentelemetry.io/docs/collector/) service can be used to
+ collect and propagate the telemetry data, or the export endpoint(s) can be listened to directly by any backend that
+ supports OTLP (e.g. Jaeger).
+- The orchestrator and client services do not need to be concerned with specific observability backends; the OTEL
+  collector and OTLP standard can be used to export telemetry data to a variety of backends including Jaeger,
+  Prometheus, Grafana, and Instana, as well as to OpenShift natively through the OpenTelemetryCollector CRD.
+- Using the `tracing` crate in Rust for logging will treat logs as traces, allowing the orchestrator to export logs
+ through the trace provider (with OTLP exporter), simplifying the implementation and avoiding use of the logging
+  provider, which is still considered experimental in many contexts (it exists for compatibility with non-`tracing`
+ logging libraries).
+- For stdout, the new `--log-format` and `--quiet` arguments add more configurability to format or silence logging.
+- The integration of the OpenTelemetry API/SDK into the stack is not trivial, and the OpenTelemetry crates will incur
+ additional compile time to the orchestrator.
+- The OpenTelemetry API/SDK and OTLP standard are new and still evolving, and the orchestrator will need to keep up
+  with changes in the OpenTelemetry ecosystem; there could be occasional breaking changes that will need addressing.
\ No newline at end of file
diff --git a/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md b/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md
new file mode 100644
index 00000000..418c25ac
--- /dev/null
+++ b/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md
@@ -0,0 +1,42 @@
+# ADR 008: Streaming orchestrator endpoints
+
+This ADR documents the patterns and behavior expected for streaming orchestrator endpoints.
+
+The orchestrator API can be found [on these GitHub pages](https://foundation-model-stack.github.io/fms-guardrails-orchestrator/).
+
+## Motivation
+
+In [ADR 004](./004-orchestrator-input-only-api-design.md), the design of "input only" detection endpoints was detailed. Currently, those endpoints support only the "unary" case, where the entire input text is available upfront. For flexibility (for example, when text is streamed from a generative model that is available to the user but not callable through the orchestrator's generation endpoints), users may still want to invoke detections on streamed input text.
+
+The orchestrator will then need to support "bidirectional streaming" endpoints, where text (whether tokens, words, or sentences) is streamed in, detectors are invoked (calling their respective chunkers via bidirectional streaming), and text processed by detectors, including any detections, is streamed back to the user.
+
+## Decisions
+
+### Server streaming or endpoint output streaming
+"Server streaming" endpoints existed already prior to the writing of this particular ADR. Streaming response aggregation behavior is documented in [ADR 002](./002-streaming-response-aggregation.md). Data will continue to be streamed back with `data` events, with errors included as `event: error` per the [SSE event format](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format).
+
+Parameters in each response event such as `start_index` and `processed_index` will indicate to the user how much of the input stream has been processed for detections, as there might not necessarily be results like positive `detections` for certain portions of the input stream. The `start_index` and `processed_index` will be relative to the entire stream.
+
+### Client streaming or endpoint input streaming
+- Any information needed for an entire request, like the `detectors` that a detection endpoint will work with, will be expected to be present in the first event of the stream (see the sketch after this list). The expected structure of stream events will be documented for each endpoint.
+ - An alternate consideration was using query or path parameters for information needed for an entire request, like `detectors`, but this would be complicated for the nesting that `detectors` require currently, with a mapping of each detector to dictionary parameters.
+ - Another alternate consideration was expecting multipart requests, one part with information for the entire request like `detectors` and another part with individual stream events. However, here the content type accepted by the request would have to change.
+- Stream closing will be the expected indication that stream events have ended.
+ - An alternate consideration is an explicit "end of stream" request message for each endpoint, for the user to indicate the connection should be closed. For example for the OpenAI chat completions API, this looks like a `[DONE]` event. The downside here is that this particular event's contents will have to be identified and processed differently from other events.
+
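+As a sketch, a client could exercise such an endpoint with an SSE client like the `reqwest-eventsource` crate. The
+endpoint, payload, and upfront-buffered NDJSON body here are illustrative; a real client would typically stream the
+request body:
+
+```rust
+use futures::StreamExt;
+use reqwest_eventsource::{Event, RequestBuilderExt};
+
+async fn consume_stream() -> Result<(), Box<dyn std::error::Error>> {
+    // First NDJSON event carries `detectors`; later events carry only `content`.
+    let body = concat!(
+        "{\"detectors\": {\"hap-v1-model-en\": {}}, \"content\": \"my text here\"}\n",
+        "{\"content\": \"more text\"}\n",
+    );
+    let mut es = reqwest::Client::new()
+        .post("http://localhost:8033/api/v2/text/detection/stream-content")
+        .header("Content-Type", "application/x-ndjson")
+        .body(body)
+        .eventsource()?;
+    while let Some(event) = es.next().await {
+        match event {
+            Ok(Event::Open) => {}
+            // Each `data` event is a content detection stream response.
+            Ok(Event::Message(msg)) => println!("data: {}", msg.data),
+            Err(err) => {
+                eprintln!("stream ended: {err}");
+                es.close();
+            }
+        }
+    }
+    Ok(())
+}
+```
+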
+### Separate streaming detection endpoints
+
+For clarity to users, we will start with endpoints that include `stream` in the endpoint name. We want to avoid adding a `stream` parameter to the request body, since this would increase the maintenance burden of parameters on each request event in the streaming case. Additionally, as detailed earlier, stream endpoint responses tend to have additional or potentially different fields than their unary counterparts. This decision can be revisited given sufficient user feedback.
+
+NOTE: This ADR will not prescribe implementation details, but while the underlying implementation _could_ use [websockets](https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API), we are explicitly not following the patterns of some websocket APIs that require connecting and disconnecting.
+
+## Consequences
+
+- Stream detection endpoints will be separate from current "unary" ones that take entire inputs and return one response. Users then must change endpoints for this different use case.
+- The orchestrator can support input or client streaming in a consistent manner. This will enable orchestrator users that may want to stream input content from other sources, like their own generative model.
+- Users have to be aware that for input streaming, the first event may need to carry additional information required by the endpoint. Thus the event message structure may not be identical across events in the stream.
+
+## Status
+
+Proposed
diff --git a/src/args.rs b/src/args.rs
new file mode 100644
index 00000000..b5b38f03
--- /dev/null
+++ b/src/args.rs
@@ -0,0 +1,214 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use std::{fmt::Display, path::PathBuf};
+
+use clap::Parser;
+use tracing::{error, warn};
+
+#[derive(Parser, Debug, Clone)]
+#[clap(author, version, about, long_about = None)]
+pub struct Args {
+ #[clap(default_value = "8033", long, env)]
+ pub http_port: u16,
+ #[clap(default_value = "8034", long, env)]
+ pub health_http_port: u16,
+ #[clap(
+ default_value = "config/config.yaml",
+ long,
+ env = "ORCHESTRATOR_CONFIG"
+ )]
+ pub config_path: PathBuf,
+ #[clap(long, env)]
+    pub tls_cert_path: Option<PathBuf>,
+    #[clap(long, env)]
+    pub tls_key_path: Option<PathBuf>,
+    #[clap(long, env)]
+    pub tls_client_ca_cert_path: Option<PathBuf>,
+ #[clap(default_value = "false", long, env)]
+ pub start_up_health_check: bool,
+ #[clap(long, env, value_delimiter = ',')]
+    pub otlp_export: Vec<OtlpExport>,
+ #[clap(default_value_t = LogFormat::default(), long, env)]
+ pub log_format: LogFormat,
+ #[clap(default_value_t = false, long, short, env)]
+ pub quiet: bool,
+ #[clap(default_value = "fms_guardrails_orchestr8", long, env)]
+ pub otlp_service_name: String,
+ #[clap(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")]
+    pub otlp_endpoint: Option<String>,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")]
+    pub otlp_traces_endpoint: Option<String>,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT")]
+    pub otlp_metrics_endpoint: Option<String>,
+ #[clap(
+ default_value_t = OtlpProtocol::Grpc,
+ long,
+ env = "OTEL_EXPORTER_OTLP_PROTOCOL"
+ )]
+ pub otlp_protocol: OtlpProtocol,
+ #[clap(long, env = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL")]
+    pub otlp_traces_protocol: Option<OtlpProtocol>,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL")]
+    pub otlp_metrics_protocol: Option<OtlpProtocol>,
+ // TODO: Add timeout and header OTLP variables
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum OtlpExport {
+ Traces,
+ Metrics,
+}
+
+impl Display for OtlpExport {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ OtlpExport::Traces => write!(f, "traces"),
+ OtlpExport::Metrics => write!(f, "metrics"),
+ }
+ }
+}
+
+impl From<String> for OtlpExport {
+ fn from(s: String) -> Self {
+ match s.to_lowercase().as_str() {
+ "traces" => OtlpExport::Traces,
+ "metrics" => OtlpExport::Metrics,
+ _ => panic!(
+ "Invalid OTLP export type {}, orchestrator only supports exporting traces and metrics via OTLP",
+ s
+ ),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+pub enum OtlpProtocol {
+ #[default]
+ Grpc,
+ Http,
+}
+
+impl Display for OtlpProtocol {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ OtlpProtocol::Grpc => write!(f, "grpc"),
+ OtlpProtocol::Http => write!(f, "http"),
+ }
+ }
+}
+
+impl From<String> for OtlpProtocol {
+ fn from(s: String) -> Self {
+ match s.to_lowercase().as_str() {
+ "grpc" => OtlpProtocol::Grpc,
+ "http" => OtlpProtocol::Http,
+ _ => {
+ error!(
+ "Invalid OTLP protocol {}, defaulting to {}",
+ s,
+ OtlpProtocol::default()
+ );
+ OtlpProtocol::default()
+ }
+ }
+ }
+}
+
+impl OtlpProtocol {
+ pub fn default_endpoint(&self) -> &str {
+ match self {
+ OtlpProtocol::Grpc => "http://localhost:4317",
+ OtlpProtocol::Http => "http://localhost:4318",
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, Default, PartialEq)]
+pub enum LogFormat {
+ #[default]
+ Full,
+ Compact,
+ Pretty,
+ JSON,
+}
+
+impl Display for LogFormat {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ LogFormat::Full => write!(f, "full"),
+ LogFormat::Compact => write!(f, "compact"),
+ LogFormat::Pretty => write!(f, "pretty"),
+ LogFormat::JSON => write!(f, "json"),
+ }
+ }
+}
+
+impl From<String> for LogFormat {
+ fn from(s: String) -> Self {
+ match s.to_lowercase().as_str() {
+ "full" => LogFormat::Full,
+ "compact" => LogFormat::Compact,
+ "pretty" => LogFormat::Pretty,
+ "json" => LogFormat::JSON,
+ _ => {
+ warn!(
+ "Invalid trace format {}, defaulting to {}",
+ s,
+ LogFormat::default()
+ );
+ LogFormat::default()
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct TracingConfig {
+ pub service_name: String,
+ pub traces: Option<(OtlpProtocol, String)>,
+ pub metrics: Option<(OtlpProtocol, String)>,
+ pub log_format: LogFormat,
+ pub quiet: bool,
+}
+
+impl From<Args> for TracingConfig {
+ fn from(args: Args) -> Self {
+ let otlp_protocol = args.otlp_protocol;
+ let otlp_endpoint = args
+ .otlp_endpoint
+ .unwrap_or(otlp_protocol.default_endpoint().to_string());
+ let otlp_traces_endpoint = args.otlp_traces_endpoint.unwrap_or(otlp_endpoint.clone());
+ let otlp_metrics_endpoint = args.otlp_metrics_endpoint.unwrap_or(otlp_endpoint.clone());
+ let otlp_traces_protocol = args.otlp_traces_protocol.unwrap_or(otlp_protocol);
+ let otlp_metrics_protocol = args.otlp_metrics_protocol.unwrap_or(otlp_protocol);
+
+ TracingConfig {
+ service_name: args.otlp_service_name,
+ traces: match args.otlp_export.contains(&OtlpExport::Traces) {
+ true => Some((otlp_traces_protocol, otlp_traces_endpoint)),
+ false => None,
+ },
+ metrics: match args.otlp_export.contains(&OtlpExport::Metrics) {
+ true => Some((otlp_metrics_protocol, otlp_metrics_endpoint)),
+ false => None,
+ },
+ log_format: args.log_format,
+ quiet: args.quiet,
+ }
+ }
+}
diff --git a/src/clients.rs b/src/clients.rs
index 4e40ae34..b7964934 100644
--- a/src/clients.rs
+++ b/src/clients.rs
@@ -16,29 +16,38 @@
*/
#![allow(dead_code)]
-// Import error for adding `source` trait
-use std::{collections::HashMap, error::Error as _, fmt::Display, pin::Pin, time::Duration};
+use std::{
+ any::TypeId,
+ collections::{hash_map, HashMap},
+ pin::Pin,
+ time::Duration,
+};
-use futures::{future::join_all, Stream};
+use async_trait::async_trait;
+use axum::http::{Extensions, HeaderMap};
+use futures::Stream;
use ginepro::LoadBalancedChannel;
-use reqwest::{Response, StatusCode};
use tokio::{fs::File, io::AsyncReadExt};
-use tracing::error;
+use tonic::{metadata::MetadataMap, Request};
+use tracing::{debug, instrument};
use url::Url;
use crate::{
config::{ServiceConfig, Tls},
- health::{HealthCheck, HealthCheckResult, HealthStatus, OptionalHealthCheckResponseBody},
+ health::HealthCheckResult,
+ tracing_utils::with_traceparent_header,
};
+pub mod errors;
+pub use errors::Error;
+
+pub mod http;
+pub use http::HttpClient;
+
pub mod chunker;
-pub use chunker::ChunkerClient;
pub mod detector;
-pub use detector::DetectorClient;
-
-pub mod generation;
-pub use generation::GenerationClient;
+pub use detector::TextContentsDetectorClient;
pub mod tgis;
pub use tgis::TgisClient;
@@ -46,358 +55,319 @@ pub use tgis::TgisClient;
pub mod nlp;
pub use nlp::NlpClient;
-pub const DEFAULT_TGIS_PORT: u16 = 8033;
-pub const DEFAULT_CAIKIT_NLP_PORT: u16 = 8085;
-pub const DEFAULT_CHUNKER_PORT: u16 = 8085;
-pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
-pub const COMMON_ROUTER_KEY: &str = "common-router";
+pub mod generation;
+pub use generation::GenerationClient;
+
+pub mod nlp_http;
+pub use nlp_http::NlpClientHttp;
+
+pub mod openai;
+
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
const DEFAULT_REQUEST_TIMEOUT_SEC: u64 = 600;
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
-/// Client errors.
-#[derive(Debug, Clone, PartialEq, thiserror::Error)]
-pub enum Error {
- #[error("{}", .message)]
- Grpc { code: StatusCode, message: String },
- #[error("{}", .message)]
- Http { code: StatusCode, message: String },
- #[error("model not found: {model_id}")]
- ModelNotFound { model_id: String },
+mod private {
+ pub struct Seal;
}
-impl Error {
- /// Returns status code.
- pub fn status_code(&self) -> StatusCode {
- match self {
- // Return equivalent http status code for grpc status code
- Error::Grpc { code, .. } => *code,
- // Return http status code for error responses
- // and 500 for other errors
- Error::Http { code, .. } => *code,
- // Return 404 for model not found
- Error::ModelNotFound { .. } => StatusCode::NOT_FOUND,
- }
+#[async_trait]
+pub trait Client: Send + Sync + 'static {
+ /// Returns the name of the client type.
+ fn name(&self) -> &str;
+
+ /// Returns the `TypeId` of the client type. Sealed to prevent overrides.
+ fn type_id(&self, _: private::Seal) -> TypeId {
+ TypeId::of::<Self>()
}
+
+ /// Performs a client health check.
+ async fn health(&self) -> HealthCheckResult;
}
-impl From for Error {
- fn from(value: reqwest::Error) -> Self {
- // Log lower level source of error.
- // Examples:
- // 1. client error (Connect) // Cases like connection error, wrong port etc.
- // 2. client error (SendRequest) // Cases like cert issues
- error!(
- "http request failed. Source: {}",
- value.source().unwrap().to_string()
- );
- // Return http status code for error responses
- // and 500 for other errors
- let code = match value.status() {
- Some(code) => code,
- None => StatusCode::INTERNAL_SERVER_ERROR,
- };
- Self::Http {
- code,
- message: value.to_string(),
- }
+impl dyn Client {
+ pub fn is<T: Client>(&self) -> bool {
+ TypeId::of::<T>() == self.type_id(private::Seal)
}
-}
-impl From for Error {
- fn from(value: tonic::Status) -> Self {
- use tonic::Code::*;
- // Return equivalent http status code for grpc status code
- let code = match value.code() {
- InvalidArgument => StatusCode::BAD_REQUEST,
- Internal => StatusCode::INTERNAL_SERVER_ERROR,
- NotFound => StatusCode::NOT_FOUND,
- DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
- Unimplemented => StatusCode::NOT_IMPLEMENTED,
- Unauthenticated => StatusCode::UNAUTHORIZED,
- PermissionDenied => StatusCode::FORBIDDEN,
- Unavailable => StatusCode::SERVICE_UNAVAILABLE,
- Ok => StatusCode::OK,
- _ => StatusCode::INTERNAL_SERVER_ERROR,
- };
- Self::Grpc {
- code,
- message: value.message().to_string(),
+ pub fn downcast<T: Client>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
+ if (*self).is::<T>() {
+ let ptr = Box::into_raw(self) as *mut T;
+ // SAFETY: guaranteed by `is`
+ unsafe { Ok(Box::from_raw(ptr)) }
+ } else {
+ Err(self)
}
}
-}
-#[derive(Debug, Clone, PartialEq)]
-pub enum ClientCode {
- Http(StatusCode),
- Grpc(tonic::Code),
-}
+ pub fn downcast_ref<T: Client>(&self) -> Option<&T> {
+ if (*self).is::<T>() {
+ let ptr = self as *const dyn Client as *const T;
+ // SAFETY: guaranteed by `is`
+ unsafe { Some(&*ptr) }
+ } else {
+ None
+ }
+ }
-impl Display for ClientCode {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- ClientCode::Http(code) => write!(f, "HTTP {}", code),
- ClientCode::Grpc(code) => write!(f, "gRPC {:?} {}", code, code),
+ pub fn downcast_mut<T: Client>(&mut self) -> Option<&mut T> {
+ if (*self).is::<T>() {
+ let ptr = self as *mut dyn Client as *mut T;
+ // SAFETY: guaranteed by `is`
+ unsafe { Some(&mut *ptr) }
+ } else {
+ None
}
}
}
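The sealed `type_id` method plus raw-pointer casts mirror how `std::any::Any` implements downcasting. A compilable miniature of the pattern, with a stand-in `FakeClient` and the async pieces stripped out:

```rust
use std::any::TypeId;

mod private {
    pub struct Seal;
}

trait Client: 'static {
    fn name(&self) -> &str;
    // Sealed: code outside this module cannot construct `Seal`, so
    // implementors cannot override this and lie about their concrete type.
    fn type_id(&self, _: private::Seal) -> TypeId {
        TypeId::of::<Self>()
    }
}

impl dyn Client {
    fn is<T: Client>(&self) -> bool {
        TypeId::of::<T>() == self.type_id(private::Seal)
    }
    fn downcast_ref<T: Client>(&self) -> Option<&T> {
        self.is::<T>().then(|| {
            // SAFETY: `is` proved the concrete type is `T`.
            unsafe { &*(self as *const dyn Client as *const T) }
        })
    }
}

struct FakeClient;
impl Client for FakeClient {
    fn name(&self) -> &str {
        "fake"
    }
}

fn main() {
    let boxed: Box<dyn Client> = Box::new(FakeClient);
    assert!(boxed.is::<FakeClient>());
    assert_eq!(boxed.downcast_ref::<FakeClient>().unwrap().name(), "fake");
}
```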
-#[derive(Clone)]
-pub struct HttpClient {
- base_url: Url,
- health_url: Url,
- client: reqwest::Client,
-}
+/// A map containing different types of clients.
+#[derive(Default)]
+pub struct ClientMap(HashMap<String, Box<dyn Client>>);
-impl HttpClient {
- pub fn new(base_url: Url, client: reqwest::Client) -> Self {
- let health_url = extract_base_url(&base_url).join("health").unwrap();
- Self {
- base_url,
- health_url,
- client,
- }
+impl ClientMap {
+ /// Creates an empty `ClientMap`.
+ #[inline]
+ pub fn new() -> Self {
+ Self(HashMap::new())
}
- pub fn base_url(&self) -> &Url {
- &self.base_url
+ /// Inserts a client into the map.
+ #[inline]
+ pub fn insert<V: Client>(&mut self, key: String, value: V) {
+ self.0.insert(key, Box::new(value));
}
- /// This is sectioned off to allow for testing.
- pub(super) async fn http_response_to_health_check_result(
- res: Result<Response, reqwest::Error>,
- ) -> HealthCheckResult {
- match res {
- Ok(response) => {
- if response.status() == StatusCode::OK {
- if let Ok(body) = response.json::<OptionalHealthCheckResponseBody>().await {
- // If the service provided a body, we only anticipate a minimal health status and optional reason.
- HealthCheckResult {
- health_status: body.health_status.clone(),
- response_code: ClientCode::Http(StatusCode::OK),
- reason: match body.health_status {
- HealthStatus::Healthy => None,
- _ => body.reason,
- },
- }
- } else {
- // If the service did not provide a body, we assume it is healthy.
- HealthCheckResult {
- health_status: HealthStatus::Healthy,
- response_code: ClientCode::Http(StatusCode::OK),
- reason: None,
- }
- }
- } else {
- HealthCheckResult {
- // The most we can presume is that 5xx errors are likely indicating service issues, implying the service is unhealthy.
- // and that 4xx errors are more likely indicating health check failures, i.e. due to configuration/implementation issues.
- // Regardless we can't be certain, so the reason is also provided.
- // TODO: We will likely circle back to re-evaluate this logic in the future
- // when we know more about how the client health results will be used.
- health_status: if response.status().as_u16() >= 500
- && response.status().as_u16() < 600
- {
- HealthStatus::Unhealthy
- } else if response.status().as_u16() >= 400
- && response.status().as_u16() < 500
- {
- HealthStatus::Unknown
- } else {
- error!(
- "unexpected http health check status code: {}",
- response.status()
- );
- HealthStatus::Unknown
- },
- response_code: ClientCode::Http(response.status()),
- reason: Some(format!(
- "{}{}",
- response.error_for_status_ref().unwrap_err(),
- response
- .text()
- .await
- .map(|s| if s.is_empty() {
- "".to_string()
- } else {
- format!(": {}", s)
- })
- .unwrap_or("".to_string())
- )),
- }
- }
- }
- Err(e) => {
- error!("error checking health: {}", e);
- HealthCheckResult {
- health_status: HealthStatus::Unknown,
- response_code: ClientCode::Http(
- e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
- ),
- reason: Some(e.to_string()),
- }
- }
- }
+ /// Returns a reference to the client trait object.
+ #[inline]
+ pub fn get(&self, key: &str) -> Option<&dyn Client> {
+ self.0.get(key).map(|v| v.as_ref())
}
-}
-impl HealthCheck for HttpClient {
- async fn check(&self) -> HealthCheckResult {
- let res = self.get(self.health_url.clone()).send().await;
- Self::http_response_to_health_check_result(res).await
+ /// Returns a mutable reference to the client trait object.
+ #[inline]
+ pub fn get_mut(&mut self, key: &str) -> Option<&mut dyn Client> {
+ self.0.get_mut(key).map(|v| v.as_mut())
+ }
+
+ /// Downcasts and returns a reference to the concrete client type.
+ #[inline]
+ pub fn get_as<V: Client>(&self, key: &str) -> Option<&V> {
+ self.0.get(key)?.downcast_ref::<V>()
}
-}
-impl std::ops::Deref for HttpClient {
- type Target = reqwest::Client;
+ /// Downcasts and returns a mutable reference to the concrete client type.
+ #[inline]
+ pub fn get_mut_as<V: Client>(&mut self, key: &str) -> Option<&mut V> {
+ self.0.get_mut(key)?.downcast_mut::<V>()
+ }
- fn deref(&self) -> &Self::Target {
- &self.client
+ /// Removes a client from the map.
+ #[inline]
+ pub fn remove(&mut self, key: &str) -> Option<Box<dyn Client>> {
+ self.0.remove(key)
+ }
+
+ /// An iterator visiting all key-value pairs in arbitrary order.
+ #[inline]
+ pub fn iter(&self) -> hash_map::Iter<'_, String, Box<dyn Client>> {
+ self.0.iter()
+ }
+
+ /// An iterator visiting all keys in arbitrary order.
+ #[inline]
+ pub fn keys(&self) -> hash_map::Keys<'_, String, Box<dyn Client>> {
+ self.0.keys()
+ }
+
+ /// An iterator visiting all values in arbitrary order.
+ #[inline]
+ pub fn values(&self) -> hash_map::Values<'_, String, Box<dyn Client>> {
+ self.0.values()
+ }
+
+ /// Returns the number of elements in the map.
+ #[inline]
+ pub fn len(&self) -> usize {
+ self.0.len()
+ }
+
+ /// Returns `true` if the map contains no elements.
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.0.is_empty()
}
}
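Usage-wise, `ClientMap` is a thin wrapper over `HashMap<String, Box<dyn Client>>`. A sketch of the idea with stand-in types (the real clients are constructed from service configs):

```rust
use std::collections::HashMap;

// Stand-ins; the real trait and clients live in this crate's modules.
trait Client {
    fn name(&self) -> &str;
}
struct ChunkerClient;
impl Client for ChunkerClient {
    fn name(&self) -> &str {
        "chunker"
    }
}

fn main() {
    // What ClientMap wraps: service id -> boxed client trait object.
    let mut clients: HashMap<String, Box<dyn Client>> = HashMap::new();
    clients.insert("sentence-chunker".into(), Box::new(ChunkerClient));
    // Uniform dispatch through the trait object; the real map's
    // `get_as::<ChunkerClient>(..)` additionally recovers the concrete
    // type via the sealed TypeId downcast shown above.
    assert_eq!(clients["sentence-chunker"].name(), "chunker");
}
```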
-pub async fn create_http_clients(
- default_port: u16,
- config: &[(String, ServiceConfig)],
-) -> HashMap<String, HttpClient> {
- let clients = config
- .iter()
- .map(|(name, service_config)| async move {
- let port = service_config.port.unwrap_or(default_port);
- let mut base_url = Url::parse(&service_config.hostname).unwrap();
- base_url.set_port(Some(port)).unwrap();
- let request_timeout = Duration::from_secs(
- service_config
- .request_timeout
- .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
- );
- let mut builder = reqwest::ClientBuilder::new()
- .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
- .timeout(request_timeout);
- if let Some(Tls::Config(tls_config)) = &service_config.tls {
- let mut cert_buf = Vec::new();
- let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
- File::open(cert_path)
- .await
- .unwrap_or_else(|error| {
- panic!("error reading cert from {cert_path:?}: {error}")
- })
- .read_to_end(&mut cert_buf)
- .await
- .unwrap();
-
- if let Some(key_path) = &tls_config.key_path {
- File::open(key_path)
- .await
- .unwrap_or_else(|error| {
- panic!("error reading key from {key_path:?}: {error}")
- })
- .read_to_end(&mut cert_buf)
- .await
- .unwrap();
- }
- let identity = reqwest::Identity::from_pem(&cert_buf).unwrap_or_else(|error| {
- panic!("error parsing bundled client certificate: {error}")
- });
+#[instrument(skip_all, fields(hostname = service_config.hostname))]
+pub async fn create_http_client(default_port: u16, service_config: &ServiceConfig) -> HttpClient {
+ let port = service_config.port.unwrap_or(default_port);
+ let protocol = match service_config.tls {
+ Some(_) => "https",
+ None => "http",
+ };
+ let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname))
+ .unwrap_or_else(|e| panic!("error parsing base url: {}", e));
+ base_url
+ .set_port(Some(port))
+ .unwrap_or_else(|_| panic!("error setting port: {}", port));
+ debug!(%base_url, "creating HTTP client");
+ let request_timeout = Duration::from_secs(
+ service_config
+ .request_timeout
+ .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
+ );
+ let mut builder = reqwest::ClientBuilder::new()
+ .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
+ .timeout(request_timeout);
+ if let Some(Tls::Config(tls_config)) = &service_config.tls {
+ let mut cert_buf = Vec::new();
+ let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
+ File::open(cert_path)
+ .await
+ .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"))
+ .read_to_end(&mut cert_buf)
+ .await
+ .unwrap();
+
+ if let Some(key_path) = &tls_config.key_path {
+ File::open(key_path)
+ .await
+ .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"))
+ .read_to_end(&mut cert_buf)
+ .await
+ .unwrap();
+ }
+ let identity = reqwest::Identity::from_pem(&cert_buf)
+ .unwrap_or_else(|error| panic!("error parsing bundled client certificate: {error}"));
- builder = builder.use_rustls_tls().identity(identity);
- builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false));
-
- if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
- let ca_cert =
- tokio::fs::read(client_ca_cert_path)
- .await
- .unwrap_or_else(|error| {
- panic!("error reading cert from {client_ca_cert_path:?}: {error}")
- });
- let cacert = reqwest::Certificate::from_pem(&ca_cert)
- .unwrap_or_else(|error| panic!("error parsing ca cert: {error}"));
- builder = builder.add_root_certificate(cacert)
- }
- }
- let client = builder
- .build()
- .unwrap_or_else(|error| panic!("error creating http client for {name}: {error}"));
- let client = HttpClient::new(base_url, client);
- (name.clone(), client)
- })
- .collect::<Vec<_>>();
- join_all(clients).await.into_iter().collect()
+ builder = builder.use_rustls_tls().identity(identity);
+ builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false));
+
+ if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
+ let ca_cert = tokio::fs::read(client_ca_cert_path)
+ .await
+ .unwrap_or_else(|error| {
+ panic!("error reading cert from {client_ca_cert_path:?}: {error}")
+ });
+ let cacert = reqwest::Certificate::from_pem(&ca_cert)
+ .unwrap_or_else(|error| panic!("error parsing ca cert: {error}"));
+ builder = builder.add_root_certificate(cacert)
+ }
+ }
+ let client = builder
+ .build()
+ .unwrap_or_else(|error| panic!("error creating http client: {error}"));
+ HttpClient::new(base_url, client)
}
-async fn create_grpc_clients<C>(
+#[instrument(skip_all, fields(hostname = service_config.hostname))]
+pub async fn create_grpc_client<C>(
default_port: u16,
- config: &[(String, ServiceConfig)],
+ service_config: &ServiceConfig,
new: fn(LoadBalancedChannel) -> C,
-) -> HashMap<String, C> {
- let clients = config
- .iter()
- .map(|(name, service_config)| async move {
- let request_timeout = Duration::from_secs(service_config.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC));
- let mut builder = LoadBalancedChannel::builder((
- service_config.hostname.clone(),
- service_config.port.unwrap_or(default_port),
- ))
- .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
- .timeout(request_timeout);
-
- let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
- let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
- let key_path = tls_config.key_path.as_ref().unwrap().as_path();
- let cert_pem = tokio::fs::read(cert_path)
- .await
- .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"));
- let key_pem = tokio::fs::read(key_path)
+) -> C {
+ let port = service_config.port.unwrap_or(default_port);
+ let protocol = match service_config.tls {
+ Some(_) => "https",
+ None => "http",
+ };
+ let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap();
+ base_url.set_port(Some(port)).unwrap();
+ debug!(%base_url, "creating gRPC client");
+ let request_timeout = Duration::from_secs(
+ service_config
+ .request_timeout
+ .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
+ );
+ let mut builder = LoadBalancedChannel::builder((service_config.hostname.clone(), port))
+ .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
+ .timeout(request_timeout);
+
+ let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
+ let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
+ let key_path = tls_config.key_path.as_ref().unwrap().as_path();
+ let cert_pem = tokio::fs::read(cert_path)
+ .await
+ .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"));
+ let key_pem = tokio::fs::read(key_path)
+ .await
+ .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"));
+ let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
+ let mut client_tls_config = tonic::transport::ClientTlsConfig::new()
+ .identity(identity)
+ .with_native_roots()
+ .with_webpki_roots();
+ if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
+ let client_ca_cert_pem =
+ tokio::fs::read(client_ca_cert_path)
.await
- .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"));
- let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
- let mut client_tls_config =
- tonic::transport::ClientTlsConfig::new().identity(identity).with_native_roots().with_webpki_roots();
- if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
- let client_ca_cert_pem = tokio::fs::read(client_ca_cert_path)
- .await
- .unwrap_or_else(|error| {
- panic!("error reading client ca cert from {client_ca_cert_path:?}: {error}")
- });
- client_tls_config = client_tls_config.ca_certificate(
- tonic::transport::Certificate::from_pem(client_ca_cert_pem),
- );
- }
- Some(client_tls_config)
- } else {
- None
- };
- if let Some(client_tls_config) = client_tls_config {
- builder = builder.with_tls(client_tls_config);
- }
- let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}"));
- (name.clone(), new(channel))
- })
- .collect::<Vec<_>>();
- join_all(clients).await.into_iter().collect()
+ .unwrap_or_else(|error| {
+ panic!("error reading client ca cert from {client_ca_cert_path:?}: {error}")
+ });
+ client_tls_config = client_tls_config
+ .ca_certificate(tonic::transport::Certificate::from_pem(client_ca_cert_pem));
+ }
+ Some(client_tls_config)
+ } else {
+ None
+ };
+ if let Some(client_tls_config) = client_tls_config {
+ builder = builder.with_tls(client_tls_config);
+ }
+ let channel = builder
+ .channel()
+ .await
+ .unwrap_or_else(|error| panic!("error creating grpc client: {error}"));
+ new(channel)
}
-/// Extracts a base url from a url including path segments.
-fn extract_base_url(url: &Url) -> Url {
- let mut url = url.clone();
- match url.path_segments_mut() {
- Ok(mut path) => {
- path.clear();
- }
- Err(_) => {
- panic!("url cannot be a base");
- }
+/// Returns `true` if hostname is valid according to [IETF RFC 1123](https://tools.ietf.org/html/rfc1123).
+///
+/// Conditions:
+/// - It does not start or end with `-` or `.`.
+/// - It does not contain any characters outside of the alphanumeric range, except for `-` and `.`.
+/// - It is not empty.
+/// - It is 253 or fewer characters.
+/// - Its labels (characters separated by `.`) are not empty.
+/// - Its labels are 63 or fewer characters.
+/// - Its labels do not start or end with '-' or '.'.
+pub fn is_valid_hostname(hostname: &str) -> bool {
+ fn is_valid_char(byte: u8) -> bool {
+ byte.is_ascii_lowercase()
+ || byte.is_ascii_uppercase()
+ || byte.is_ascii_digit()
+ || byte == b'-'
+ || byte == b'.'
}
- url
+ !(hostname.bytes().any(|byte| !is_valid_char(byte))
+ || hostname.split('.').any(|label| {
+ label.is_empty() || label.len() > 63 || label.starts_with('-') || label.ends_with('-')
+ })
+ || hostname.is_empty()
+ || hostname.len() > 253)
+}
+
+/// Turns a gRPC request body of type `T` and a header map into a `tonic::Request<T>`.
+/// Also injects a `traceparent` header derived from the current span.
+fn grpc_request_with_headers<T>(request: T, headers: HeaderMap) -> Request<T> {
+ let headers = with_traceparent_header(headers);
+ let metadata = MetadataMap::from_headers(headers);
+ Request::from_parts(metadata, Extensions::new(), request)
}
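A sketch of what `grpc_request_with_headers` produces, with the `traceparent` injection left out. It relies on the same `MetadataMap::from_headers` and `Request::from_parts` calls used above, and assumes the `http` versions behind axum and tonic match, as they do in this patch:

```rust
use axum::http::{Extensions, HeaderMap};
use tonic::{metadata::MetadataMap, Request};

fn main() {
    let mut headers = HeaderMap::new();
    headers.insert("mm-model-id", "en_chunker".parse().unwrap());
    // HTTP headers carry over 1:1 into gRPC metadata.
    let metadata = MetadataMap::from_headers(headers);
    let request = Request::from_parts(metadata, Extensions::new(), "body");
    assert_eq!(
        request.metadata().get("mm-model-id").unwrap().to_str().unwrap(),
        "en_chunker"
    );
}
```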
#[cfg(test)]
mod tests {
- use hyper::http;
+ use errors::grpc_to_http_code;
+ use hyper::{http, StatusCode};
+ use reqwest::Response;
use super::*;
- use crate::pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse};
+ use crate::{
+ health::{HealthCheckResult, HealthStatus},
+ pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse},
+ };
async fn mock_http_response(
status: StatusCode,
@@ -429,83 +399,69 @@ mod tests {
// READY responses from HTTP 200 OK with or without reason
let response = [
(StatusCode::OK, r#"{}"#),
- (StatusCode::OK, r#"{ "health_status": "HEALTHY" }"#),
+ (StatusCode::OK, r#"{ "status": "HEALTHY" }"#),
+ (StatusCode::OK, r#"{ "status": "meaningless status" }"#),
(
StatusCode::OK,
- r#"{ "health_status": "meaningless status" }"#,
- ),
- (
- StatusCode::OK,
- r#"{ "health_status": "HEALTHY", "reason": "needless reason" }"#,
+ r#"{ "status": "HEALTHY", "reason": "needless reason" }"#,
),
];
for (status, body) in response.iter() {
let response = mock_http_response(*status, body).await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Healthy);
- assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+ assert_eq!(result.status, HealthStatus::Healthy);
+ assert_eq!(result.code, StatusCode::OK);
assert_eq!(result.reason, None);
let serialized = serde_json::to_string(&result).unwrap();
- assert_eq!(serialized, r#""HEALTHY""#);
+ assert_eq!(serialized, r#"{"status":"HEALTHY"}"#);
}
// NOT_READY response from HTTP 200 OK without reason
- let response =
- mock_http_response(StatusCode::OK, r#"{ "health_status": "UNHEALTHY" }"#).await;
+ let response = mock_http_response(StatusCode::OK, r#"{ "status": "UNHEALTHY" }"#).await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unhealthy);
- assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+ assert_eq!(result.status, HealthStatus::Unhealthy);
+ assert_eq!(result.code, StatusCode::OK);
assert_eq!(result.reason, None);
let serialized = serde_json::to_string(&result).unwrap();
- assert_eq!(
- serialized,
- r#"{"health_status":"UNHEALTHY","response_code":"HTTP 200 OK"}"#
- );
+ assert_eq!(serialized, r#"{"status":"UNHEALTHY"}"#);
// UNKNOWN response from HTTP 200 OK without reason
- let response =
- mock_http_response(StatusCode::OK, r#"{ "health_status": "UNKNOWN" }"#).await;
+ let response = mock_http_response(StatusCode::OK, r#"{ "status": "UNKNOWN" }"#).await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unknown);
- assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+ assert_eq!(result.status, HealthStatus::Unknown);
+ assert_eq!(result.code, StatusCode::OK);
assert_eq!(result.reason, None);
let serialized = serde_json::to_string(&result).unwrap();
- assert_eq!(
- serialized,
- r#"{"health_status":"UNKNOWN","response_code":"HTTP 200 OK"}"#
- );
+ assert_eq!(serialized, r#"{"status":"UNKNOWN"}"#);
// NOT_READY response from HTTP 200 OK with reason
let response = mock_http_response(
StatusCode::OK,
- r#"{ "health_status": "UNHEALTHY", "reason": "some reason" }"#,
+ r#"{"status": "UNHEALTHY", "reason": "some reason" }"#,
)
.await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unhealthy);
- assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+ assert_eq!(result.status, HealthStatus::Unhealthy);
+ assert_eq!(result.code, StatusCode::OK);
assert_eq!(result.reason, Some("some reason".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
- r#"{"health_status":"UNHEALTHY","response_code":"HTTP 200 OK","reason":"some reason"}"#
+ r#"{"status":"UNHEALTHY","reason":"some reason"}"#
);
// UNKNOWN response from HTTP 200 OK with reason
let response = mock_http_response(
StatusCode::OK,
- r#"{ "health_status": "UNKNOWN", "reason": "some reason" }"#,
+ r#"{ "status": "UNKNOWN", "reason": "some reason" }"#,
)
.await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unknown);
- assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+ assert_eq!(result.status, HealthStatus::Unknown);
+ assert_eq!(result.code, StatusCode::OK);
assert_eq!(result.reason, Some("some reason".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
- assert_eq!(
- serialized,
- r#"{"health_status":"UNKNOWN","response_code":"HTTP 200 OK","reason":"some reason"}"#
- );
+ assert_eq!(serialized, r#"{"status":"UNKNOWN","reason":"some reason"}"#);
// NOT_READY response from HTTP 503 SERVICE UNAVAILABLE with reason
let response = mock_http_response(
@@ -514,16 +470,13 @@ mod tests {
)
.await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unhealthy);
- assert_eq!(
- result.response_code,
- ClientCode::Http(StatusCode::SERVICE_UNAVAILABLE)
- );
- assert_eq!(result.reason, Some(r#"HTTP status server error (503 Service Unavailable) for url (http://no.url.provided.local/): { "message": "some error message" }"#.to_string()));
+ assert_eq!(result.status, HealthStatus::Unhealthy);
+ assert_eq!(result.code, StatusCode::SERVICE_UNAVAILABLE);
+ assert_eq!(result.reason, Some("Service Unavailable".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
- r#"{"health_status":"UNHEALTHY","response_code":"HTTP 503 Service Unavailable","reason":"HTTP status server error (503 Service Unavailable) for url (http://no.url.provided.local/): { \"message\": \"some error message\" }"}"#
+ r#"{"status":"UNHEALTHY","code":503,"reason":"Service Unavailable"}"#
);
// UNKNOWN response from HTTP 404 NOT FOUND with reason
@@ -533,46 +486,37 @@ mod tests {
)
.await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unknown);
- assert_eq!(
- result.response_code,
- ClientCode::Http(StatusCode::NOT_FOUND)
- );
- assert_eq!(result.reason, Some(r#"HTTP status client error (404 Not Found) for url (http://no.url.provided.local/): { "message": "service not found" }"#.to_string()));
+ assert_eq!(result.status, HealthStatus::Unknown);
+ assert_eq!(result.code, StatusCode::NOT_FOUND);
+ assert_eq!(result.reason, Some("Not Found".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
- r#"{"health_status":"UNKNOWN","response_code":"HTTP 404 Not Found","reason":"HTTP status client error (404 Not Found) for url (http://no.url.provided.local/): { \"message\": \"service not found\" }"}"#
+ r#"{"status":"UNKNOWN","code":404,"reason":"Not Found"}"#
);
// NOT_READY response from HTTP 500 INTERNAL SERVER ERROR without reason
let response = mock_http_response(StatusCode::INTERNAL_SERVER_ERROR, r#""#).await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unhealthy);
- assert_eq!(
- result.response_code,
- ClientCode::Http(StatusCode::INTERNAL_SERVER_ERROR)
- );
- assert_eq!(result.reason, Some("HTTP status server error (500 Internal Server Error) for url (http://no.url.provided.local/)".to_string()));
+ assert_eq!(result.status, HealthStatus::Unhealthy);
+ assert_eq!(result.code, StatusCode::INTERNAL_SERVER_ERROR);
+ assert_eq!(result.reason, Some("Internal Server Error".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
- r#"{"health_status":"UNHEALTHY","response_code":"HTTP 500 Internal Server Error","reason":"HTTP status server error (500 Internal Server Error) for url (http://no.url.provided.local/)"}"#
+ r#"{"status":"UNHEALTHY","code":500,"reason":"Internal Server Error"}"#
);
// UNKNOWN response from HTTP 400 BAD REQUEST without reason
let response = mock_http_response(StatusCode::BAD_REQUEST, r#""#).await;
let result = HttpClient::http_response_to_health_check_result(response).await;
- assert_eq!(result.health_status, HealthStatus::Unknown);
- assert_eq!(
- result.response_code,
- ClientCode::Http(StatusCode::BAD_REQUEST)
- );
- assert_eq!(result.reason, Some("HTTP status client error (400 Bad Request) for url (http://no.url.provided.local/)".to_string()));
+ assert_eq!(result.status, HealthStatus::Unknown);
+ assert_eq!(result.code, StatusCode::BAD_REQUEST);
+ assert_eq!(result.reason, Some("Bad Request".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
- r#"{"health_status":"UNKNOWN","response_code":"HTTP 400 Bad Request","reason":"HTTP status client error (400 Bad Request) for url (http://no.url.provided.local/)"}"#
+ r#"{"status":"UNKNOWN","code":400,"reason":"Bad Request"}"#
);
}
@@ -581,59 +525,47 @@ mod tests {
// READY responses from gRPC 0 OK from serving status 1 SERVING
let response = mock_grpc_response(Some(ServingStatus::Serving as i32), None).await;
let result = HealthCheckResult::from(response);
- assert_eq!(result.health_status, HealthStatus::Healthy);
- assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
+ assert_eq!(result.status, HealthStatus::Healthy);
+ assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
assert_eq!(result.reason, None);
let serialized = serde_json::to_string(&result).unwrap();
- assert_eq!(serialized, r#""HEALTHY""#);
+ assert_eq!(serialized, r#"{"status":"HEALTHY"}"#);
// NOT_READY response from gRPC 0 OK form serving status 2 NOT_SERVING
let response = mock_grpc_response(Some(ServingStatus::NotServing as i32), None).await;
let result = HealthCheckResult::from(response);
- assert_eq!(result.health_status, HealthStatus::Unhealthy);
- assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
- assert_eq!(
- result.reason,
- Some("from gRPC health check serving status: NOT_SERVING".to_string())
- );
+ assert_eq!(result.status, HealthStatus::Unhealthy);
+ assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
+ assert_eq!(result.reason, Some("NOT_SERVING".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
- r#"{"health_status":"UNHEALTHY","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: NOT_SERVING"}"#
+ r#"{"status":"UNHEALTHY","reason":"NOT_SERVING"}"#
);
// UNKNOWN response from gRPC 0 OK from serving status 0 UNKNOWN
let response = mock_grpc_response(Some(ServingStatus::Unknown as i32), None).await;
let result = HealthCheckResult::from(response);
- assert_eq!(result.health_status, HealthStatus::Unknown);
- assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
- assert_eq!(
- result.reason,
- Some("from gRPC health check serving status: UNKNOWN".to_string())
- );
+ assert_eq!(result.status, HealthStatus::Unknown);
+ assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
+ assert_eq!(result.reason, Some("UNKNOWN".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
- assert_eq!(
- serialized,
- r#"{"health_status":"UNKNOWN","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: UNKNOWN"}"#
- );
+ assert_eq!(serialized, r#"{"status":"UNKNOWN","reason":"UNKNOWN"}"#);
// UNKNOWN response from gRPC 0 OK from serving status 3 SERVICE_UNKNOWN
let response = mock_grpc_response(Some(ServingStatus::ServiceUnknown as i32), None).await;
let result = HealthCheckResult::from(response);
- assert_eq!(result.health_status, HealthStatus::Unknown);
- assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
- assert_eq!(
- result.reason,
- Some("from gRPC health check serving status: SERVICE_UNKNOWN".to_string())
- );
+ assert_eq!(result.status, HealthStatus::Unknown);
+ assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
+ assert_eq!(result.reason, Some("SERVICE_UNKNOWN".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
- r#"{"health_status":"UNKNOWN","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: SERVICE_UNKNOWN"}"#
+ r#"{"status":"UNKNOWN","reason":"SERVICE_UNKNOWN"}"#
);
// UNKNOWN response from other gRPC error codes (covering main ones)
- let response_codes = [
+ let codes = [
tonic::Code::InvalidArgument,
tonic::Code::Internal,
tonic::Code::NotFound,
@@ -642,40 +574,40 @@ mod tests {
tonic::Code::PermissionDenied,
tonic::Code::Unavailable,
];
- for code in response_codes.iter() {
- let status = tonic::Status::new(*code, "some error message");
+ for code in codes.into_iter() {
+ let status = tonic::Status::new(code, "some error message");
+ let code = grpc_to_http_code(code);
let response = mock_grpc_response(None, Some(status.clone())).await;
let result = HealthCheckResult::from(response);
- assert_eq!(result.health_status, HealthStatus::Unknown);
- assert_eq!(result.response_code, ClientCode::Grpc(*code));
- assert_eq!(
- result.reason,
- Some(format!("gRPC health check failed: {}", status.clone()))
- );
+ assert_eq!(result.status, HealthStatus::Unknown);
+ assert_eq!(result.code, code);
+ assert_eq!(result.reason, Some("some error message".to_string()));
let serialized = serde_json::to_string(&result).unwrap();
assert_eq!(
serialized,
format!(
- r#"{{"health_status":"UNKNOWN","response_code":"gRPC {:?} {}","reason":"gRPC health check failed: status: {:?}, message: \"some error message\", details: [], metadata: MetadataMap {{ headers: {{}} }}"}}"#,
- code, code, code
- )
+ r#"{{"status":"UNKNOWN","code":{},"reason":"some error message"}}"#,
+ code.as_u16()
+ ),
);
}
}
#[test]
- fn test_extract_base_url() {
- let url =
- Url::parse("https://example-detector.route.example.com/api/v1/text/contents").unwrap();
- let base_url = extract_base_url(&url);
- assert_eq!(
- Url::parse("https://example-detector.route.example.com/").unwrap(),
- base_url
- );
- let health_url = base_url.join("/health").unwrap();
- assert_eq!(
- Url::parse("https://example-detector.route.example.com/health").unwrap(),
- health_url
- );
+ fn test_is_valid_hostname() {
+ let valid_hostnames = ["localhost", "example.route.cloud.com", "127.0.0.1"];
+ for hostname in valid_hostnames {
+ assert!(is_valid_hostname(hostname));
+ }
+ let invalid_hostnames = [
+ "-LoCaLhOST_",
+ ".invalid",
+ "invalid.ending-.char",
+ "@asdf",
+ "too-long-of-a-hostnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
+ ];
+ for hostname in invalid_hostnames {
+ assert!(!is_valid_hostname(hostname));
+ }
}
}
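The expected strings in these tests imply a serialization shape for `HealthCheckResult` in which `code` is omitted for 200 OK and `reason` is omitted when absent. A hedged sketch of a serde layout that reproduces those strings (the real struct and its `Serialize` impl are not shown in this patch; this is inferred, not copied):

```rust
use serde::Serialize;

// Assumed wire shape, inferred from the test expectations above.
#[derive(Serialize)]
struct HealthCheckResult {
    status: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    code: Option<u16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reason: Option<String>,
}

fn main() {
    let healthy = HealthCheckResult { status: "HEALTHY", code: None, reason: None };
    assert_eq!(serde_json::to_string(&healthy).unwrap(), r#"{"status":"HEALTHY"}"#);

    let unhealthy = HealthCheckResult {
        status: "UNHEALTHY",
        code: Some(503),
        reason: Some("Service Unavailable".into()),
    };
    assert_eq!(
        serde_json::to_string(&unhealthy).unwrap(),
        r#"{"status":"UNHEALTHY","code":503,"reason":"Service Unavailable"}"#
    );
}
```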
diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs
index 35b04591..bd924b90 100644
--- a/src/clients/chunker.rs
+++ b/src/clients/chunker.rs
@@ -15,19 +15,22 @@
*/
-use std::{collections::HashMap, pin::Pin};
+use std::pin::Pin;
+use async_trait::async_trait;
+use axum::http::HeaderMap;
use futures::{Future, Stream, StreamExt, TryStreamExt};
use ginepro::LoadBalancedChannel;
-use tokio::sync::mpsc;
-use tokio_stream::wrappers::ReceiverStream;
-use tonic::{Request, Response, Status, Streaming};
-use tracing::info;
+use tonic::{Code, Request, Response, Status, Streaming};
+use tracing::{info, instrument};
-use super::{create_grpc_clients, BoxStream, Error};
+use super::{
+ create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client,
+ Error,
+};
use crate::{
config::ServiceConfig,
- health::{HealthCheckResult, HealthProbe},
+ health::{HealthCheckResult, HealthStatus},
pb::{
caikit::runtime::chunkers::{
chunkers_service_client::ChunkersServiceClient,
@@ -36,115 +39,104 @@ use crate::{
caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults},
grpc::health::v1::{health_client::HealthClient, HealthCheckRequest},
},
+ tracing_utils::trace_context_from_grpc_response,
};
+const DEFAULT_PORT: u16 = 8085;
const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
/// Default chunker that returns span for entire text
-pub const DEFAULT_MODEL_ID: &str = "whole_doc_chunker";
+pub const DEFAULT_CHUNKER_ID: &str = "whole_doc_chunker";
type StreamingTokenizationResult =
Result<Response<Streaming<ChunkerTokenizationStreamResult>>, Status>;
-#[cfg_attr(test, faux::create, derive(Default))]
+#[cfg_attr(test, faux::create)]
#[derive(Clone)]
pub struct ChunkerClient {
- clients: HashMap<String, ChunkersServiceClient<LoadBalancedChannel>>,
- health_clients: HashMap<String, HealthClient<LoadBalancedChannel>>,
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for ChunkerClient {
- async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
- let mut results = HashMap::with_capacity(self.health_clients.len());
- for (model_id, mut client) in self.health_clients.clone() {
- results.insert(
- model_id.clone(),
- client
- .check(HealthCheckRequest {
- service: "".to_string(),
- }) // Caikit does not expect a service_id to be specified
- .await
- .into(),
- );
- }
- Ok(results)
- }
+ client: ChunkersServiceClient<LoadBalancedChannel>,
+ health_client: HealthClient<LoadBalancedChannel>,
}
#[cfg_attr(test, faux::methods)]
impl ChunkerClient {
- pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
- let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await;
- let health_clients = create_grpc_clients(default_port, config, HealthClient::new).await;
+ pub async fn new(config: &ServiceConfig) -> Self {
+ let client = create_grpc_client(DEFAULT_PORT, config, ChunkersServiceClient::new).await;
+ let health_client = create_grpc_client(DEFAULT_PORT, config, HealthClient::new).await;
Self {
- clients,
- health_clients,
+ client,
+ health_client,
}
}
- fn client(&self, model_id: &str) -> Result<ChunkersServiceClient<LoadBalancedChannel>, Error> {
- Ok(self
- .clients
- .get(model_id)
- .ok_or_else(|| Error::ModelNotFound {
- model_id: model_id.to_string(),
- })?
- .clone())
- }
-
+ #[instrument(skip_all, fields(model_id))]
pub async fn tokenization_task_predict(
&self,
model_id: &str,
request: ChunkerTokenizationTaskRequest,
) -> Result<TokenizationResults, Error> {
- // Handle "default" separately first
- if model_id == DEFAULT_MODEL_ID {
- info!("Using default whole doc chunker");
- return Ok(tokenize_whole_doc(request));
- }
- let request = request_with_model_id(request, model_id);
- Ok(self
- .client(model_id)?
- .chunker_tokenization_task_predict(request)
- .await?
- .into_inner())
+ let mut client = self.client.clone();
+ let request = request_with_headers(request, model_id);
+ info!(?request, "sending client request");
+ let response = client.chunker_tokenization_task_predict(request).await?;
+ trace_context_from_grpc_response(&response);
+ Ok(response.into_inner())
}
+ #[instrument(skip_all, fields(model_id))]
pub async fn bidi_streaming_tokenization_task_predict(
&self,
model_id: &str,
request_stream: BoxStream<BidiStreamingChunkerTokenizationTaskRequest>,
) -> Result<BoxStream<Result<ChunkerTokenizationStreamResult, Error>>, Error> {
- let response_stream = if model_id == DEFAULT_MODEL_ID {
- info!("Using default whole doc chunker");
- let (response_tx, response_rx) = mpsc::channel(1);
- // Spawn task to collect input stream
- tokio::spawn(async move {
- // NOTE: this will not resolve until the input stream is closed
- let response = tokenize_whole_doc_stream(request_stream).await;
- let _ = response_tx.send(response).await;
- });
- ReceiverStream::new(response_rx).boxed()
+ info!("sending client stream request");
+ let mut client = self.client.clone();
+ let request = request_with_headers(request_stream, model_id);
+ // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors.
+ // https://github.com/rust-lang/rust/issues/110338
+ let response_stream_fut: Pin<Box<dyn Future<Output = StreamingTokenizationResult> + Send>> =
+ Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request));
+ let response_stream = response_stream_fut.await?;
+ trace_context_from_grpc_response(&response_stream);
+ Ok(response_stream.into_inner().map_err(Into::into).boxed())
+ }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for ChunkerClient {
+ fn name(&self) -> &str {
+ "chunker"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ let mut client = self.health_client.clone();
+ let response = client
+ .check(HealthCheckRequest { service: "".into() })
+ .await;
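+ // The server answered even if it doesn't recognize the (empty) service
+ // name, so InvalidArgument and NotFound are treated as a healthy reply.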
+ let code = match response {
+ Ok(_) => Code::Ok,
+ Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => {
+ Code::Ok
+ }
+ Err(status) => status.code(),
+ };
+ let status = if matches!(code, Code::Ok) {
+ HealthStatus::Healthy
} else {
- let mut client = self.client(model_id)?;
- let request = request_with_model_id(request_stream, model_id);
- // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors.
- // https://github.com/rust-lang/rust/issues/110338
- let response_stream_fut: Pin<
- Box<dyn Future<Output = StreamingTokenizationResult> + Send>,
- > = Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request));
- response_stream_fut
- .await?
- .into_inner()
- .map_err(Into::into)
- .boxed()
+ HealthStatus::Unhealthy
};
- Ok(response_stream)
+ HealthCheckResult {
+ status,
+ code: grpc_to_http_code(code),
+ reason: None,
+ }
}
}
-fn request_with_model_id<T>(request: T, model_id: &str) -> Request<T> {
- let mut request = Request::new(request);
+/// Turns a chunker client gRPC request body of type `T` into a `tonic::Request<T>` with headers.
+/// Adds the provided `model_id` as a header and injects `traceparent` from the current span.
+fn request_with_headers<T>(request: T, model_id: &str) -> Request<T> {
+ let mut request = grpc_request_with_headers(request, HeaderMap::new());
request
.metadata_mut()
.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
@@ -152,7 +144,8 @@ fn request_with_model_id<T>(request: T, model_id: &str) -> Request<T> {
}
/// Unary tokenization result of the entire doc
-fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults {
+#[instrument(skip_all)]
+pub fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults {
let codepoint_count = request.text.chars().count() as i64;
TokenizationResults {
results: vec![Token {
@@ -165,7 +158,8 @@ fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationRe
}
/// Streaming tokenization result for the entire doc stream
-async fn tokenize_whole_doc_stream(
+#[instrument(skip_all)]
+pub async fn tokenize_whole_doc_stream(
request: impl Stream<Item = BidiStreamingChunkerTokenizationTaskRequest>,
) -> Result<ChunkerTokenizationStreamResult, Error> {
let (text, index_vec): (String, Vec<i64>) = request
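For reference, a compilable miniature of the whole-doc behavior (a simplified `Token`; the real functions fill the generated protobuf types): the entire input becomes a single token spanning `[0, codepoint_count)`, counted in chars rather than bytes.

```rust
struct Token {
    start: i64,
    end: i64,
    text: String,
}

// Miniature of `tokenize_whole_doc`.
fn tokenize_whole_doc(text: &str) -> Token {
    let codepoint_count = text.chars().count() as i64;
    Token {
        start: 0,
        end: codepoint_count,
        text: text.to_string(),
    }
}

fn main() {
    let token = tokenize_whole_doc("héllo"); // 5 codepoints, 6 bytes
    assert_eq!((token.start, token.end), (0, 5));
    assert_eq!(token.text, "héllo");
}
```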
diff --git a/src/clients/detector.rs b/src/clients/detector.rs
index 455612bc..cf06ddef 100644
--- a/src/clients/detector.rs
+++ b/src/clients/detector.rs
@@ -15,234 +15,30 @@
*/
-use std::collections::HashMap;
+use std::fmt::Debug;
-use hyper::{HeaderMap, StatusCode};
+use axum::http::HeaderMap;
+use hyper::StatusCode;
+use reqwest::Response;
use serde::{Deserialize, Serialize};
-
-use super::{create_http_clients, Error, HttpClient};
-use crate::{
- config::ServiceConfig,
- health::{HealthCheck, HealthCheckResult, HealthProbe},
- models::{DetectionResult, DetectorParams},
-};
-
+use tracing::info;
+use url::Url;
+
+pub mod text_contents;
+pub use text_contents::*;
+pub mod text_chat;
+pub use text_chat::*;
+pub mod text_context_doc;
+pub use text_context_doc::*;
+pub mod text_generation;
+pub use text_generation::*;
+
+use super::{Error, HttpClient};
+use crate::tracing_utils::{trace_context_from_http_response, with_traceparent_header};
+
+const DEFAULT_PORT: u16 = 8080;
const DETECTOR_ID_HEADER_NAME: &str = "detector-id";
-// For some reason the order matters here. #[cfg_attr(test, derive(Default), faux::create)] doesn't work. (rustc --explain E0560)
-#[cfg_attr(test, faux::create, derive(Default))]
-#[derive(Clone)]
-pub struct DetectorClient {
- clients: HashMap<String, HttpClient>,
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for DetectorClient {
- async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
- let mut results = HashMap::with_capacity(self.clients.len());
- for (model_id, client) in self.clients() {
- results.insert(model_id.to_string(), client.check().await);
- }
- Ok(results)
- }
-}
-
-#[cfg_attr(test, faux::methods)]
-impl DetectorClient {
- pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
- let clients: HashMap<String, HttpClient> = create_http_clients(default_port, config).await;
- Self { clients }
- }
-
- fn client(&self, model_id: &str) -> Result {
- Ok(self
- .clients
- .get(model_id)
- .ok_or_else(|| Error::ModelNotFound {
- model_id: model_id.to_string(),
- })?
- .clone())
- }
-
- fn clients(&self) -> impl Iterator<Item = (&String, &HttpClient)> {
- self.clients.iter()
- }
-
- // TODO: Use generics here, since the only thing that changes in comparison to generation_detection()
- // is the "request" parameter and return types?
- /// Invokes detectors implemented with the `/api/v1/text/contents` endpoint
- pub async fn text_contents(
- &self,
- model_id: &str,
- request: ContentAnalysisRequest,
- headers: HeaderMap,
- ) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
- let client = self.client(model_id)?;
- let url = client.base_url().as_str();
- let response = client
- .post(url)
- .headers(headers)
- .header(DETECTOR_ID_HEADER_NAME, model_id)
- .json(&request)
- .send()
- .await?;
- if response.status() == StatusCode::OK {
- Ok(response.json().await?)
- } else {
- let code = response.status().as_u16();
- let error = response
- .json::<DetectorError>()
- .await
- .unwrap_or(DetectorError {
- code,
- message: "".into(),
- });
- Err(error.into())
- }
- }
-
- /// Invokes detectors implemented with the `/api/v1/text/generation` endpoint
- pub async fn text_generation(
- &self,
- model_id: &str,
- request: GenerationDetectionRequest,
- headers: HeaderMap,
- ) -> Result<Vec<DetectionResult>, Error> {
- let client = self.client(model_id)?;
- let url = client.base_url().as_str();
- let response = client
- .post(url)
- .headers(headers)
- .header(DETECTOR_ID_HEADER_NAME, model_id)
- .json(&request)
- .send()
- .await?;
- if response.status() == StatusCode::OK {
- Ok(response.json().await?)
- } else {
- let code = response.status().as_u16();
- let error = response
- .json::<DetectorError>()
- .await
- .unwrap_or(DetectorError {
- code,
- message: "".into(),
- });
- Err(error.into())
- }
- }
-
- /// Invokes detectors implemented with the `/api/v1/text/context/doc` endpoint
- pub async fn text_context_doc(
- &self,
- model_id: &str,
- request: ContextDocsDetectionRequest,
- headers: HeaderMap,
- ) -> Result<Vec<DetectionResult>, Error> {
- let client = self.client(model_id)?;
- let url = client.base_url().as_str();
- let response = client
- .post(url)
- .headers(headers)
- .header(DETECTOR_ID_HEADER_NAME, model_id)
- .json(&request)
- .send()
- .await?;
- if response.status() == StatusCode::OK {
- Ok(response.json().await?)
- } else {
- let code = response.status().as_u16();
- let error = response
- .json::<DetectorError>()
- .await
- .unwrap_or(DetectorError {
- code,
- message: "".into(),
- });
- Err(error.into())
- }
- }
-}
-
-/// Request for text content analysis
-/// Results of this request will contain analysis / detection of each of the provided documents
-/// in the order they are present in the `contents` object.
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub struct ContentAnalysisRequest {
- /// Field allowing users to provide list of documents for analysis
- pub contents: Vec<String>,
-}
-
-impl ContentAnalysisRequest {
- pub fn new(contents: Vec<String>) -> ContentAnalysisRequest {
- ContentAnalysisRequest { contents }
- }
-}
-
-/// Evidence
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub struct Evidence {
- /// Evidence name
- pub name: String,
- /// Optional, evidence value
- #[serde(skip_serializing_if = "Option::is_none")]
- pub value: Option<String>,
- /// Optional, score for evidence
- #[serde(skip_serializing_if = "Option::is_none")]
- pub score: Option<f64>,
-}
-
-/// Evidence in response
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
-pub struct EvidenceObj {
- /// Evidence name
- pub name: String,
- /// Optional, evidence value
- #[serde(skip_serializing_if = "Option::is_none")]
- pub value: Option<String>,
- /// Optional, score for evidence
- #[serde(skip_serializing_if = "Option::is_none")]
- pub score: Option<f64>,
- /// Optional, evidence on evidence value
- // Evidence nesting should likely not go beyond this
- #[serde(skip_serializing_if = "Option::is_none")]
- pub evidence: Option<Vec<Evidence>>,
-}
-
-/// Response of text content analysis endpoint
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub struct ContentAnalysisResponse {
- /// Start index of detection
- pub start: usize,
- /// End index of detection
- pub end: usize,
- /// Text corresponding to detection
- pub text: String,
- /// Relevant detection class
- pub detection: String,
- /// Detection type or aggregate detection label
- pub detection_type: String,
- /// Score of detection
- pub score: f64,
- /// Optional, any applicable evidence for detection
- #[serde(skip_serializing_if = "Option::is_none")]
- pub evidence: Option<Vec<EvidenceObj>>,
-}
-
-impl From for crate::models::TokenClassificationResult {
- fn from(value: ContentAnalysisResponse) -> Self {
- Self {
- start: value.start as u32,
- end: value.end as u32,
- word: value.text,
- entity: value.detection,
- entity_group: value.detection_type,
- score: value.score,
- token_count: None,
- }
- }
-}
-
#[derive(Debug, Clone, Deserialize)]
pub struct DetectorError {
pub code: u16,
@@ -258,66 +54,24 @@ impl From for Error {
}
}
-/// A struct representing a request to a detector compatible with the
-/// /api/v1/text/generation endpoint.
-#[cfg_attr(test, derive(PartialEq))]
-#[derive(Debug, Serialize)]
-pub struct GenerationDetectionRequest {
- /// User prompt sent to LLM
- pub prompt: String,
-
- /// Text generated from an LLM
- pub generated_text: String,
-}
-
-impl GenerationDetectionRequest {
- pub fn new(prompt: String, generated_text: String) -> Self {
- Self {
- prompt,
- generated_text,
- }
- }
-}
-
-/// A struct representing a request to a detector compatible with the
-/// /api/v1/text/context/doc endpoint.
-#[cfg_attr(test, derive(PartialEq))]
-#[derive(Debug, Serialize)]
-pub struct ContextDocsDetectionRequest {
- /// Content to run detection on
- pub content: String,
-
- /// Type of context being sent
- pub context_type: ContextType,
-
- /// Context to run detection on
- pub context: Vec<String>,
-
- // Detector Params
- pub detector_params: DetectorParams,
-}
-
-/// Enum representing the context type of a detection
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub enum ContextType {
- #[serde(rename = "docs")]
- Document,
- #[serde(rename = "url")]
- Url,
-}
-
-impl ContextDocsDetectionRequest {
- pub fn new(
- content: String,
- context_type: ContextType,
- context: Vec<String>,
- detector_params: DetectorParams,
- ) -> Self {
- Self {
- content,
- context_type,
- context,
- detector_params,
- }
- }
+/// Makes a POST request for an HTTP detector client and returns the response.
+/// Also injects the `traceparent` header from the current span and traces the response.
+pub async fn post_with_headers<T: Debug + Serialize>(
+ client: HttpClient,
+ url: Url,
+ request: T,
+ headers: HeaderMap,
+ model_id: &str,
+) -> Result<Response, Error> {
+ let headers = with_traceparent_header(headers);
+ info!(?url, ?headers, ?request, "sending client request");
+ let response = client
+ .post(url)
+ .headers(headers)
+ .header(DETECTOR_ID_HEADER_NAME, model_id)
+ .json(&request)
+ .send()
+ .await?;
+ trace_context_from_http_response(&response);
+ Ok(response)
}
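Each detector endpoint pairs `post_with_headers` with the same response handling: deserialize the expected body on 200 OK, otherwise fall back to `DetectorError`, with an empty message when the error body isn't valid JSON. A consolidated sketch of that pattern; the 500-on-decode-failure mapping is an assumption here (the real clients bubble up `reqwest` errors through the crate's `Error`), and the mocked response uses the same `From<http::Response<_>>` conversion the crate's tests rely on:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct DetectorError {
    code: u16,
    message: String,
}

// 200 -> expected body; anything else -> DetectorError with a fallback.
async fn handle<T: serde::de::DeserializeOwned>(
    response: reqwest::Response,
) -> Result<T, DetectorError> {
    if response.status() == reqwest::StatusCode::OK {
        response.json::<T>().await.map_err(|e| DetectorError {
            code: 500, // assumption: treat a decode failure as a server error
            message: e.to_string(),
        })
    } else {
        let code = response.status().as_u16();
        Err(response
            .json::<DetectorError>()
            .await
            .unwrap_or(DetectorError {
                code,
                message: "".into(),
            }))
    }
}

#[tokio::main]
async fn main() {
    let http = http::Response::builder()
        .status(404)
        .body(r#"{ "code": 404, "message": "detector not found" }"#)
        .unwrap();
    let err = handle::<serde_json::Value>(reqwest::Response::from(http))
        .await
        .unwrap_err();
    assert_eq!((err.code, err.message.as_str()), (404, "detector not found"));
}
```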
diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs
new file mode 100644
index 00000000..bb6fc524
--- /dev/null
+++ b/src/clients/detector/text_chat.rs
@@ -0,0 +1,125 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::Serialize;
+use tracing::{debug, info, instrument};
+
+use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME};
+use crate::{
+ clients::{create_http_client, openai::Message, Client, Error, HttpClient},
+ config::ServiceConfig,
+ health::HealthCheckResult,
+ models::{DetectionResult, DetectorParams},
+};
+
+const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat";
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextChatDetectorClient {
+ client: HttpClient,
+ health_client: Option,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextChatDetectorClient {
+ pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+ let client = create_http_client(DEFAULT_PORT, config).await;
+ let health_client = if let Some(health_config) = health_config {
+ Some(create_http_client(DEFAULT_PORT, health_config).await)
+ } else {
+ None
+ };
+ Self {
+ client,
+ health_client,
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id, ?headers))]
+ pub async fn text_chat(
+ &self,
+ model_id: &str,
+ request: ChatDetectionRequest,
+ headers: HeaderMap,
+ ) -> Result<Vec<DetectionResult>, Error> {
+ let url = self.client.base_url().join(CHAT_DETECTOR_ENDPOINT).unwrap();
+ info!(?url, "sending chat detector client request");
+ let request = self
+ .client
+ .post(url)
+ .headers(headers)
+ .header(DETECTOR_ID_HEADER_NAME, model_id)
+ .json(&request);
+ debug!("chat detector client request: {:?}", request);
+ let response = request.send().await?;
+ debug!("chat detector client response: {:?}", response);
+
+ if response.status() == StatusCode::OK {
+ Ok(response.json().await?)
+ } else {
+ let code = response.status().as_u16();
+ let error = response
+ .json::<DetectorError>()
+ .await
+ .unwrap_or(DetectorError {
+ code,
+ message: "".into(),
+ });
+ Err(error.into())
+ }
+ }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextChatDetectorClient {
+ fn name(&self) -> &str {
+ "text_chat_detector"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ if let Some(health_client) = &self.health_client {
+ health_client.health().await
+ } else {
+ self.client.health().await
+ }
+ }
+}
+
+/// A struct representing a request to a detector compatible with the
+/// /api/v1/text/chat endpoint.
+// #[cfg_attr(test, derive(PartialEq))]
+#[derive(Debug, Serialize)]
+pub struct ChatDetectionRequest {
+ /// Chat messages to run detection on
+ pub messages: Vec<Message>,
+
+ /// Detector parameters (available parameters depend on the detector)
+ pub detector_params: DetectorParams,
+}
+
+impl ChatDetectionRequest {
+ pub fn new(messages: Vec<Message>, detector_params: DetectorParams) -> Self {
+ Self {
+ messages,
+ detector_params,
+ }
+ }
+}
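For reference, the JSON this request serializes to on the wire, assuming the crate's `openai::Message` serializes to role/content pairs (an assumption, since that type is not shown in this patch):

```rust
use serde_json::json;

fn main() {
    // Illustrative body for POST /api/v1/text/chat; field names for
    // messages and detector params are assumptions.
    let body = json!({
        "messages": [{ "role": "user", "content": "hello" }],
        "detector_params": { "threshold": 0.5 },
    });
    assert_eq!(body["messages"][0]["role"], "user");
}
```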
diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs
new file mode 100644
index 00000000..d1015b1d
--- /dev/null
+++ b/src/clients/detector/text_contents.rs
@@ -0,0 +1,182 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::{Deserialize, Serialize};
+use tracing::instrument;
+
+use super::{post_with_headers, DetectorError, DEFAULT_PORT};
+use crate::{
+ clients::{create_http_client, Client, Error, HttpClient},
+ config::ServiceConfig,
+ health::HealthCheckResult,
+ models::DetectorParams,
+};
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextContentsDetectorClient {
+ client: HttpClient,
+ health_client: Option,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextContentsDetectorClient {
+ pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+ let client = create_http_client(DEFAULT_PORT, config).await;
+ let health_client = if let Some(health_config) = health_config {
+ Some(create_http_client(DEFAULT_PORT, health_config).await)
+ } else {
+ None
+ };
+ Self {
+ client,
+ health_client,
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id))]
+ pub async fn text_contents(
+ &self,
+ model_id: &str,
+ request: ContentAnalysisRequest,
+ headers: HeaderMap,
+ ) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
+ let url = self
+ .client
+ .base_url()
+ .join("/api/v1/text/contents")
+ .unwrap();
+ let response =
+ post_with_headers(self.client.clone(), url, request, headers, model_id).await?;
+ if response.status() == StatusCode::OK {
+ Ok(response.json().await?)
+ } else {
+ let code = response.status().as_u16();
+ let error = response
+ .json::<DetectorError>()
+ .await
+ .unwrap_or(DetectorError {
+ code,
+ message: "".into(),
+ });
+ Err(error.into())
+ }
+ }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextContentsDetectorClient {
+ fn name(&self) -> &str {
+ "text_contents_detector"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ if let Some(health_client) = &self.health_client {
+ health_client.health().await
+ } else {
+ self.client.health().await
+ }
+ }
+}
+
+/// Request for text content analysis
+/// Results of this request will contain analysis / detection of each of the provided documents
+/// in the order they are present in the `contents` object.
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct ContentAnalysisRequest {
+ /// Field allowing users to provide list of documents for analysis
+ pub contents: Vec<String>,
+
+ /// Detector parameters (available parameters depend on the detector)
+ pub detector_params: DetectorParams,
+}
+
+impl ContentAnalysisRequest {
+ pub fn new(contents: Vec<String>, detector_params: DetectorParams) -> ContentAnalysisRequest {
+ ContentAnalysisRequest {
+ contents,
+ detector_params,
+ }
+ }
+}
+
+/// Response of text content analysis endpoint
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct ContentAnalysisResponse {
+ /// Start index of detection
+ pub start: usize,
+ /// End index of detection
+ pub end: usize,
+ /// Text corresponding to detection
+ pub text: String,
+ /// Relevant detection class
+ pub detection: String,
+ /// Detection type or aggregate detection label
+ pub detection_type: String,
+ /// Score of detection
+ pub score: f64,
+ /// Optional, any applicable evidence for detection
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub evidence: Option<Vec<EvidenceObj>>,
+}
+
+impl From<ContentAnalysisResponse> for crate::models::TokenClassificationResult {
+ fn from(value: ContentAnalysisResponse) -> Self {
+ Self {
+ start: value.start as u32,
+ end: value.end as u32,
+ word: value.text,
+ entity: value.detection,
+ entity_group: value.detection_type,
+ score: value.score,
+ token_count: None,
+ }
+ }
+}
+
+/// Evidence
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct Evidence {
+ /// Evidence name
+ pub name: String,
+ /// Optional, evidence value
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub value: Option<String>,
+ /// Optional, score for evidence
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub score: Option<f64>,
+}
+
+/// Evidence in response
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct EvidenceObj {
+ /// Evidence name
+ pub name: String,
+ /// Optional, evidence value
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub value: Option<String>,
+ /// Optional, score for evidence
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub score: Option<f64>,
+ /// Optional, evidence on evidence value
+ // Evidence nesting should likely not go beyond this
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub evidence: Option<Vec<Evidence>>,
+}
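The From impl above flattens a detector span into the orchestrator's token-classification shape. A minimal sketch of that conversion, assuming the crate paths shown in this file; the field values are illustrative:

    use crate::clients::detector::text_contents::ContentAnalysisResponse;
    use crate::models::TokenClassificationResult;

    fn content_analysis_conversion_sketch() {
        let response = ContentAnalysisResponse {
            start: 0,
            end: 5,
            text: "hello".into(),
            detection: "pii".into(),
            detection_type: "pii".into(),
            score: 0.9,
            evidence: None,
        };
        // usize offsets narrow to u32; text/detection/detection_type map to
        // word/entity/entity_group, and token_count is always None.
        let result: TokenClassificationResult = response.into();
        assert_eq!(result.start, 0);
        assert_eq!(result.word, "hello");
        assert!(result.token_count.is_none());
    }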
diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs
new file mode 100644
index 00000000..e56c1619
--- /dev/null
+++ b/src/clients/detector/text_context_doc.rs
@@ -0,0 +1,140 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::{Deserialize, Serialize};
+use tracing::instrument;
+
+use super::{post_with_headers, DetectorError, DEFAULT_PORT};
+use crate::{
+ clients::{create_http_client, Client, Error, HttpClient},
+ config::ServiceConfig,
+ health::HealthCheckResult,
+ models::{DetectionResult, DetectorParams},
+};
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextContextDocDetectorClient {
+ client: HttpClient,
+ health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextContextDocDetectorClient {
+ pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+ let client = create_http_client(DEFAULT_PORT, config).await;
+ let health_client = if let Some(health_config) = health_config {
+ Some(create_http_client(DEFAULT_PORT, health_config).await)
+ } else {
+ None
+ };
+ Self {
+ client,
+ health_client,
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id))]
+ pub async fn text_context_doc(
+ &self,
+ model_id: &str,
+ request: ContextDocsDetectionRequest,
+ headers: HeaderMap,
+ ) -> Result<Vec<DetectionResult>, Error> {
+ let url = self
+ .client
+ .base_url()
+ .join("/api/v1/text/context/doc")
+ .unwrap();
+ let response =
+ post_with_headers(self.client.clone(), url, request, headers, model_id).await?;
+ if response.status() == StatusCode::OK {
+ Ok(response.json().await?)
+ } else {
+ let code = response.status().as_u16();
+ let error = response
+ .json::<DetectorError>()
+ .await
+ .unwrap_or(DetectorError {
+ code,
+ message: "".into(),
+ });
+ Err(error.into())
+ }
+ }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextContextDocDetectorClient {
+ fn name(&self) -> &str {
+ "text_context_doc_detector"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ if let Some(health_client) = &self.health_client {
+ health_client.health().await
+ } else {
+ self.client.health().await
+ }
+ }
+}
+
+/// A struct representing a request to a detector compatible with the
+/// /api/v1/text/context/doc endpoint.
+#[cfg_attr(test, derive(PartialEq))]
+#[derive(Debug, Serialize)]
+pub struct ContextDocsDetectionRequest {
+ /// Content to run detection on
+ pub content: String,
+
+ /// Type of context being sent
+ pub context_type: ContextType,
+
+ /// Context to run detection on
+ pub context: Vec<String>,
+
+ /// Detector parameters (available parameters depend on the detector)
+ pub detector_params: DetectorParams,
+}
+
+/// Enum representing the context type of a detection
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub enum ContextType {
+ #[serde(rename = "docs")]
+ Document,
+ #[serde(rename = "url")]
+ Url,
+}
+
+impl ContextDocsDetectionRequest {
+ pub fn new(
+ content: String,
+ context_type: ContextType,
+ context: Vec<String>,
+ detector_params: DetectorParams,
+ ) -> Self {
+ Self {
+ content,
+ context_type,
+ context,
+ detector_params,
+ }
+ }
+}
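The serde renames on ContextType decide the wire values sent to the detector. A quick serialization check, assuming serde_json is available as elsewhere in the crate:

    use crate::clients::detector::text_context_doc::ContextType;

    fn context_type_wire_format_sketch() {
        // Document serializes as "docs" and Url as "url", matching the
        // /api/v1/text/context/doc contract.
        assert_eq!(serde_json::to_string(&ContextType::Document).unwrap(), "\"docs\"");
        assert_eq!(serde_json::to_string(&ContextType::Url).unwrap(), "\"url\"");
    }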
diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs
new file mode 100644
index 00000000..5a63d6c9
--- /dev/null
+++ b/src/clients/detector/text_generation.rs
@@ -0,0 +1,122 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::Serialize;
+use tracing::instrument;
+
+use super::{post_with_headers, DetectorError, DEFAULT_PORT};
+use crate::{
+ clients::{create_http_client, Client, Error, HttpClient},
+ config::ServiceConfig,
+ health::HealthCheckResult,
+ models::{DetectionResult, DetectorParams},
+};
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextGenerationDetectorClient {
+ client: HttpClient,
+ health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextGenerationDetectorClient {
+ pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+ let client = create_http_client(DEFAULT_PORT, config).await;
+ let health_client = if let Some(health_config) = health_config {
+ Some(create_http_client(DEFAULT_PORT, health_config).await)
+ } else {
+ None
+ };
+ Self {
+ client,
+ health_client,
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id))]
+ pub async fn text_generation(
+ &self,
+ model_id: &str,
+ request: GenerationDetectionRequest,
+ headers: HeaderMap,
+ ) -> Result<Vec<DetectionResult>, Error> {
+ let url = self
+ .client
+ .base_url()
+ .join("/api/v1/text/generation")
+ .unwrap();
+ let response =
+ post_with_headers(self.client.clone(), url, request, headers, model_id).await?;
+ if response.status() == StatusCode::OK {
+ Ok(response.json().await?)
+ } else {
+ let code = response.status().as_u16();
+ let error = response
+ .json::<DetectorError>()
+ .await
+ .unwrap_or(DetectorError {
+ code,
+ message: "".into(),
+ });
+ Err(error.into())
+ }
+ }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextGenerationDetectorClient {
+ fn name(&self) -> &str {
+ "text_context_doc_detector"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ if let Some(health_client) = &self.health_client {
+ health_client.health().await
+ } else {
+ self.client.health().await
+ }
+ }
+}
+
+/// A struct representing a request to a detector compatible with the
+/// /api/v1/text/generation endpoint.
+#[cfg_attr(test, derive(PartialEq))]
+#[derive(Debug, Serialize)]
+pub struct GenerationDetectionRequest {
+ /// User prompt sent to LLM
+ pub prompt: String,
+
+ /// Text generated from an LLM
+ pub generated_text: String,
+
+ /// Detector parameters (available parameters depend on the detector)
+ pub detector_params: DetectorParams,
+}
+
+impl GenerationDetectionRequest {
+ pub fn new(prompt: String, generated_text: String, detector_params: DetectorParams) -> Self {
+ Self {
+ prompt,
+ generated_text,
+ detector_params,
+ }
+ }
+}
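A small usage sketch for the constructor above; it assumes DetectorParams implements Default, which holds if it is the usual map-like parameter bag:

    use crate::clients::detector::text_generation::GenerationDetectionRequest;
    use crate::models::DetectorParams;

    fn generation_detection_request_sketch() {
        // The detector receives the original prompt and the LLM output side by side.
        let request = GenerationDetectionRequest::new(
            "Where do foxes live?".into(),
            "Foxes live in dens.".into(),
            DetectorParams::default(), // assumption: DetectorParams: Default
        );
        assert_eq!(request.prompt, "Where do foxes live?");
        assert_eq!(request.generated_text, "Foxes live in dens.");
    }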
diff --git a/src/clients/errors.rs b/src/clients/errors.rs
new file mode 100644
index 00000000..e630e21a
--- /dev/null
+++ b/src/clients/errors.rs
@@ -0,0 +1,96 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use std::error::Error as _;
+
+use hyper::StatusCode;
+use tracing::error;
+
+/// Client errors.
+#[derive(Debug, Clone, PartialEq, thiserror::Error)]
+pub enum Error {
+ #[error("{}", .message)]
+ Grpc { code: StatusCode, message: String },
+ #[error("{}", .message)]
+ Http { code: StatusCode, message: String },
+ #[error("model not found: {model_id}")]
+ ModelNotFound { model_id: String },
+}
+
+impl Error {
+ /// Returns status code.
+ pub fn status_code(&self) -> StatusCode {
+ match self {
+ // Return equivalent http status code for grpc status code
+ Error::Grpc { code, .. } => *code,
+ // Return http status code for error responses
+ // and 500 for other errors
+ Error::Http { code, .. } => *code,
+ // Return 404 for model not found
+ Error::ModelNotFound { .. } => StatusCode::NOT_FOUND,
+ }
+ }
+}
+
+impl From<reqwest::Error> for Error {
+ fn from(value: reqwest::Error) -> Self {
+ // Log lower level source of error.
+ // Examples:
+ // 1. client error (Connect) // Cases like connection error, wrong port etc.
+ // 2. client error (SendRequest) // Cases like cert issues
+ error!(
+ "http request failed. Source: {}",
+ value.source().map(|e| e.to_string()).unwrap_or_default()
+ );
+ // Return http status code for error responses
+ // and 500 for other errors
+ let code = match value.status() {
+ Some(code) => code,
+ None => StatusCode::INTERNAL_SERVER_ERROR,
+ };
+ Self::Http {
+ code,
+ message: value.to_string(),
+ }
+ }
+}
+
+impl From<tonic::Status> for Error {
+ fn from(value: tonic::Status) -> Self {
+ Self::Grpc {
+ code: grpc_to_http_code(value.code()),
+ message: value.message().to_string(),
+ }
+ }
+}
+
+/// Returns equivalent http status code for grpc status code
+pub fn grpc_to_http_code(value: tonic::Code) -> StatusCode {
+ use tonic::Code::*;
+ match value {
+ InvalidArgument => StatusCode::BAD_REQUEST,
+ Internal => StatusCode::INTERNAL_SERVER_ERROR,
+ NotFound => StatusCode::NOT_FOUND,
+ DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
+ Unimplemented => StatusCode::NOT_IMPLEMENTED,
+ Unauthenticated => StatusCode::UNAUTHORIZED,
+ PermissionDenied => StatusCode::FORBIDDEN,
+ Unavailable => StatusCode::SERVICE_UNAVAILABLE,
+ Ok => StatusCode::OK,
+ _ => StatusCode::INTERNAL_SERVER_ERROR,
+ }
+}
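grpc_to_http_code is a pure function, so the mapping above can be spot-checked directly:

    use hyper::StatusCode;
    use crate::clients::errors::grpc_to_http_code;

    fn grpc_code_mapping_sketch() {
        assert_eq!(grpc_to_http_code(tonic::Code::NotFound), StatusCode::NOT_FOUND);
        assert_eq!(grpc_to_http_code(tonic::Code::Unavailable), StatusCode::SERVICE_UNAVAILABLE);
        // Codes without an explicit arm fall through to 500.
        assert_eq!(grpc_to_http_code(tonic::Code::DataLoss), StatusCode::INTERNAL_SERVER_ERROR);
    }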
diff --git a/src/clients/generation.rs b/src/clients/generation.rs
index d599d8c1..c5068dc1 100644
--- a/src/clients/generation.rs
+++ b/src/clients/generation.rs
@@ -15,15 +15,14 @@
*/
-use std::collections::HashMap;
-
+use async_trait::async_trait;
use futures::{StreamExt, TryStreamExt};
use hyper::HeaderMap;
-use tracing::debug;
+use tracing::{debug, instrument};
-use super::{BoxStream, Error, NlpClient, TgisClient};
+use super::{BoxStream, Client, Error, NlpClient, NlpClientHttp, TgisClient};
use crate::{
- health::{HealthCheckResult, HealthProbe},
+ health::HealthCheckResult,
models::{
ClassifiedGeneratedTextResult, ClassifiedGeneratedTextStreamResult,
GuardrailsTextGenerationParameters,
@@ -40,7 +39,7 @@ use crate::{
},
};
-#[cfg_attr(test, faux::create, derive(Default))]
+#[cfg_attr(test, faux::create)]
#[derive(Clone)]
pub struct GenerationClient(Option<GenerationClientInner>);
@@ -48,24 +47,7 @@ pub struct GenerationClient(Option<GenerationClientInner>);
enum GenerationClientInner {
Tgis(TgisClient),
Nlp(NlpClient),
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for GenerationClient {
- async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
- match &self.0 {
- Some(GenerationClientInner::Tgis(client)) => client.health().await,
- Some(GenerationClientInner::Nlp(client)) => client.health().await,
- None => Ok(HashMap::new()),
- }
- }
-}
-
-#[cfg(test)]
-impl Default for GenerationClientInner {
- fn default() -> Self {
- Self::Tgis(TgisClient::default())
- }
+ NlpHttp(NlpClientHttp),
}
#[cfg_attr(test, faux::methods)]
@@ -78,10 +60,15 @@ impl GenerationClient {
Self(Some(GenerationClientInner::Nlp(client)))
}
+ pub fn nlp_http(client: NlpClientHttp) -> Self {
+ Self(Some(GenerationClientInner::NlpHttp(client)))
+ }
+
pub fn not_configured() -> Self {
Self(None)
}
+ #[instrument(skip_all, fields(model_id))]
pub async fn tokenize(
&self,
model_id: String,
@@ -97,19 +84,33 @@ impl GenerationClient {
return_offsets: false,
truncate_input_tokens: 0,
};
- debug!(%model_id, provider = "tgis", ?request, "sending tokenize request");
+ debug!(provider = "tgis", ?request, "sending tokenize request");
let mut response = client.tokenize(request, headers).await?;
- debug!(%model_id, provider = "tgis", ?response, "received tokenize response");
+ debug!(provider = "tgis", ?response, "received tokenize response");
let response = response.responses.swap_remove(0);
Ok((response.token_count, response.tokens))
}
Some(GenerationClientInner::Nlp(client)) => {
let request = TokenizationTaskRequest { text };
- debug!(%model_id, provider = "nlp", ?request, "sending tokenize request");
+ debug!(provider = "nlp", ?request, "sending tokenize request");
+ let response = client
+ .tokenization_task_predict(&model_id, request, headers)
+ .await?;
+ debug!(provider = "nlp", ?response, "received tokenize response");
+ let tokens = response
+ .results
+ .into_iter()
+ .map(|token| token.text)
+ .collect::<Vec<_>>();
+ Ok((response.token_count as u32, tokens))
+ }
+ Some(GenerationClientInner::NlpHttp(client)) => {
+ let request = TokenizationTaskRequest { text };
+ debug!(provider = "nlp-http", ?request, "sending tokenize request");
let response = client
.tokenization_task_predict(&model_id, request, headers)
.await?;
- debug!(%model_id, provider = "nlp", ?response, "received tokenize response");
+ debug!(provider = "nlp-http", ?response, "received tokenize response");
let tokens = response
.results
.into_iter()
@@ -121,6 +122,7 @@ impl GenerationClient {
}
}
+ #[instrument(skip_all, fields(model_id))]
pub async fn generate(
&self,
model_id: String,
@@ -137,9 +139,9 @@ impl GenerationClient {
requests: vec![GenerationRequest { text }],
params,
};
- debug!(%model_id, provider = "tgis", ?request, "sending generate request");
+ debug!(provider = "tgis", ?request, "sending generate request");
let response = client.generate(request, headers).await?;
- debug!(%model_id, provider = "tgis", ?response, "received generate response");
+ debug!(provider = "tgis", ?response, "received generate response");
Ok(response.into())
}
Some(GenerationClientInner::Nlp(client)) => {
@@ -174,17 +176,58 @@ impl GenerationClient {
..Default::default()
}
};
- debug!(%model_id, provider = "nlp", ?request, "sending generate request");
+ debug!(provider = "nlp", ?request, "sending generate request");
let response = client
.text_generation_task_predict(&model_id, request, headers)
.await?;
- debug!(%model_id, provider = "nlp", ?response, "received generate response");
+ debug!(provider = "nlp", ?response, "received generate response");
Ok(response.into())
}
+ Some(GenerationClientInner::NlpHttp(client)) => {
+ let request = if let Some(params) = params {
+ TextGenerationTaskRequest {
+ text,
+ max_new_tokens: params.max_new_tokens.map(|v| v as i64),
+ min_new_tokens: params.min_new_tokens.map(|v| v as i64),
+ truncate_input_tokens: params.truncate_input_tokens.map(|v| v as i64),
+ decoding_method: params.decoding_method,
+ top_k: params.top_k.map(|v| v as i64),
+ top_p: params.top_p,
+ typical_p: params.typical_p,
+ temperature: params.temperature,
+ repetition_penalty: params.repetition_penalty,
+ max_time: params.max_time,
+ exponential_decay_length_penalty: params
+ .exponential_decay_length_penalty
+ .map(Into::into),
+ stop_sequences: params.stop_sequences.unwrap_or_default(),
+ seed: params.seed.map(|v| v as u64),
+ preserve_input_text: params.preserve_input_text,
+ input_tokens: params.input_tokens,
+ generated_tokens: params.generated_tokens,
+ token_logprobs: params.token_logprobs,
+ token_ranks: params.token_ranks,
+ include_stop_sequence: params.include_stop_sequence,
+ }
+ } else {
+ TextGenerationTaskRequest {
+ text,
+ ..Default::default()
+ }
+ };
+ debug!(provider = "nlp-http", ?request, "sending generate request");
+ let response = client
+ .text_generation_task_predict(&model_id, request, headers)
+ .await?;
+ debug!(provider = "nlp-http", ?response, "received generate response");
+ Ok(response.into())
+ }
None => Err(Error::ModelNotFound { model_id }),
}
}
+ #[instrument(skip_all, fields(model_id))]
pub async fn generate_stream(
&self,
model_id: String,
@@ -201,7 +244,11 @@ impl GenerationClient {
request: Some(GenerationRequest { text }),
params,
};
- debug!(%model_id, provider = "tgis", ?request, "sending generate_stream request");
+ debug!(
+ provider = "tgis",
+ ?request,
+ "sending generate_stream request"
+ );
let response_stream = client
.generate_stream(request, headers)
.await?
@@ -241,7 +288,11 @@ impl GenerationClient {
..Default::default()
}
};
- debug!(%model_id, provider = "nlp", ?request, "sending generate_stream request");
+ debug!(
+ provider = "nlp",
+ ?request,
+ "sending generate_stream request"
+ );
let response_stream = client
.server_streaming_text_generation_task_predict(&model_id, request, headers)
.await?
@@ -249,7 +300,67 @@ impl GenerationClient {
.boxed();
Ok(response_stream)
}
+ Some(GenerationClientInner::NlpHttp(client)) => {
+ let request = if let Some(params) = params {
+ ServerStreamingTextGenerationTaskRequest {
+ text,
+ max_new_tokens: params.max_new_tokens.map(|v| v as i64),
+ min_new_tokens: params.min_new_tokens.map(|v| v as i64),
+ truncate_input_tokens: params.truncate_input_tokens.map(|v| v as i64),
+ decoding_method: params.decoding_method,
+ top_k: params.top_k.map(|v| v as i64),
+ top_p: params.top_p,
+ typical_p: params.typical_p,
+ temperature: params.temperature,
+ repetition_penalty: params.repetition_penalty,
+ max_time: params.max_time,
+ exponential_decay_length_penalty: params
+ .exponential_decay_length_penalty
+ .map(Into::into),
+ stop_sequences: params.stop_sequences.unwrap_or_default(),
+ seed: params.seed.map(|v| v as u64),
+ preserve_input_text: params.preserve_input_text,
+ input_tokens: params.input_tokens,
+ generated_tokens: params.generated_tokens,
+ token_logprobs: params.token_logprobs,
+ token_ranks: params.token_ranks,
+ include_stop_sequence: params.include_stop_sequence,
+ }
+ } else {
+ ServerStreamingTextGenerationTaskRequest {
+ text,
+ ..Default::default()
+ }
+ };
+ debug!(
+ provider = "nlp-http",
+ ?request,
+ "sending generate_stream request"
+ );
+ let response_stream = client
+ .server_streaming_text_generation_task_predict(&model_id, request, headers)
+ .await?
+ .map_ok(Into::into)
+ .boxed();
+ Ok(response_stream)
+ }
None => Err(Error::ModelNotFound { model_id }),
}
}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for GenerationClient {
+ fn name(&self) -> &str {
+ "generation"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ match &self.0 {
+ Some(GenerationClientInner::Tgis(client)) => client.health().await,
+ Some(GenerationClientInner::Nlp(client)) => client.health().await,
+ Some(GenerationClientInner::NlpHttp(client)) => client.health().await,
+ None => unimplemented!(),
+ }
+ }
}
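A sketch of the not-configured path, assuming tokenize keeps the (model_id, text, headers) signature shown above and, like generate, returns ModelNotFound when no generation backend is configured:

    use hyper::HeaderMap;
    use crate::clients::{generation::GenerationClient, Error};

    async fn not_configured_sketch() {
        let client = GenerationClient::not_configured();
        let result = client
            .tokenize("my-model".to_string(), "hi".to_string(), HeaderMap::new())
            .await;
        assert!(matches!(result, Err(Error::ModelNotFound { .. })));
    }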
diff --git a/src/clients/http.rs b/src/clients/http.rs
new file mode 100644
index 00000000..862db811
--- /dev/null
+++ b/src/clients/http.rs
@@ -0,0 +1,156 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use hyper::StatusCode;
+use reqwest::Response;
+use tracing::error;
+use url::Url;
+
+use crate::health::{HealthCheckResult, HealthStatus, OptionalHealthCheckResponseBody};
+
+#[derive(Clone)]
+pub struct HttpClient {
+ base_url: Url,
+ health_url: Url,
+ client: reqwest::Client,
+}
+
+impl HttpClient {
+ pub fn new(base_url: Url, client: reqwest::Client) -> Self {
+ let health_url = base_url.join("health").unwrap();
+ Self {
+ base_url,
+ health_url,
+ client,
+ }
+ }
+
+ pub fn base_url(&self) -> &Url {
+ &self.base_url
+ }
+
+ /// This is sectioned off to allow for testing.
+ pub(super) async fn http_response_to_health_check_result(
+ res: Result,
+ ) -> HealthCheckResult {
+ match res {
+ Ok(response) => {
+ if response.status() == StatusCode::OK {
+ if let Ok(body) = response.json::<OptionalHealthCheckResponseBody>().await {
+ // If the service provided a body, we only anticipate a minimal health status and optional reason.
+ HealthCheckResult {
+ status: body.status.clone(),
+ code: StatusCode::OK,
+ reason: match body.status {
+ HealthStatus::Healthy => None,
+ _ => body.reason,
+ },
+ }
+ } else {
+ // If the service did not provide a body, we assume it is healthy.
+ HealthCheckResult {
+ status: HealthStatus::Healthy,
+ code: StatusCode::OK,
+ reason: None,
+ }
+ }
+ } else {
+ HealthCheckResult {
+ // The most we can presume is that 5xx errors are likely indicating service issues, implying the service is unhealthy.
+ // and that 4xx errors are more likely indicating health check failures, i.e. due to configuration/implementation issues.
+ // Regardless we can't be certain, so the reason is also provided.
+ // TODO: We will likely circle back to re-evaluate this logic in the future
+ // when we know more about how the client health results will be used.
+ status: if response.status().as_u16() >= 500
+ && response.status().as_u16() < 600
+ {
+ HealthStatus::Unhealthy
+ } else if response.status().as_u16() >= 400
+ && response.status().as_u16() < 500
+ {
+ HealthStatus::Unknown
+ } else {
+ error!(
+ "unexpected http health check status code: {}",
+ response.status()
+ );
+ HealthStatus::Unknown
+ },
+ code: response.status(),
+ reason: response.status().canonical_reason().map(|v| v.to_string()),
+ }
+ }
+ }
+ Err(e) => {
+ error!("error checking health: {}", e);
+ HealthCheckResult {
+ status: HealthStatus::Unknown,
+ code: e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
+ reason: Some(e.to_string()),
+ }
+ }
+ }
+ }
+
+ pub async fn health(&self) -> HealthCheckResult {
+ let res = self.get(self.health_url.clone()).send().await;
+ Self::http_response_to_health_check_result(res).await
+ }
+}
+
+impl std::ops::Deref for HttpClient {
+ type Target = reqwest::Client;
+
+ fn deref(&self) -> &Self::Target {
+ &self.client
+ }
+}
+
+/// Extracts a base url from a url including path segments.
+pub fn extract_base_url(url: &Url) -> Option<Url> {
+ let mut url = url.clone();
+ match url.path_segments_mut() {
+ Ok(mut path) => {
+ path.clear();
+ }
+ Err(_) => {
+ return None;
+ }
+ }
+ Some(url)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_extract_base_url() {
+ let url =
+ Url::parse("https://example-detector.route.example.com/api/v1/text/contents").unwrap();
+ let base_url = extract_base_url(&url);
+ assert_eq!(
+ Some(Url::parse("https://example-detector.route.example.com/").unwrap()),
+ base_url
+ );
+ let health_url = base_url.map(|v| v.join("/health").unwrap());
+ assert_eq!(
+ Some(Url::parse("https://example-detector.route.example.com/health").unwrap()),
+ health_url
+ );
+ }
+}
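The 5xx-vs-4xx mapping above is easy to spot-check. A sketch, runnable from inside the clients module since the helper is pub(super); it assumes reqwest's From<http::Response<T>> conversion for building test responses:

    use crate::clients::http::HttpClient;
    use crate::health::HealthStatus;

    async fn health_mapping_sketch() {
        // 5xx server errors map to Unhealthy.
        let response = hyper::Response::builder().status(503).body("").unwrap();
        let result = HttpClient::http_response_to_health_check_result(Ok(response.into())).await;
        assert!(matches!(result.status, HealthStatus::Unhealthy));
        // 4xx client errors map to Unknown.
        let response = hyper::Response::builder().status(404).body("").unwrap();
        let result = HttpClient::http_response_to_health_check_result(Ok(response.into())).await;
        assert!(matches!(result.status, HealthStatus::Unknown));
    }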
diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs
index 90b3f693..01067019 100644
--- a/src/clients/nlp.rs
+++ b/src/clients/nlp.rs
@@ -15,18 +15,20 @@
*/
-use std::collections::HashMap;
-
-use axum::http::{Extensions, HeaderMap};
+use async_trait::async_trait;
+use axum::http::HeaderMap;
use futures::{StreamExt, TryStreamExt};
use ginepro::LoadBalancedChannel;
-use tonic::{metadata::MetadataMap, Request};
+use tonic::{Code, Request};
+use tracing::{info, instrument};
-use super::{create_grpc_clients, BoxStream, Error};
+use super::{
+ create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client,
+ Error,
+};
use crate::{
- clients::COMMON_ROUTER_KEY,
config::ServiceConfig,
- health::{HealthCheckResult, HealthProbe},
+ health::{HealthCheckResult, HealthStatus},
pb::{
caikit::runtime::nlp::{
nlp_service_client::NlpServiceClient, ServerStreamingTextGenerationTaskRequest,
@@ -38,122 +40,130 @@ use crate::{
},
grpc::health::v1::{health_client::HealthClient, HealthCheckRequest},
},
+ tracing_utils::trace_context_from_grpc_response,
};
+const DEFAULT_PORT: u16 = 8085;
const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
-#[cfg_attr(test, faux::create, derive(Default))]
+#[cfg_attr(test, faux::create)]
#[derive(Clone)]
pub struct NlpClient {
- clients: HashMap<String, NlpServiceClient<LoadBalancedChannel>>,
- health_clients: HashMap<String, HealthClient<LoadBalancedChannel>>,
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for NlpClient {
- async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
- let mut results = HashMap::with_capacity(self.health_clients.len());
- for (model_id, mut client) in self.health_clients.clone() {
- results.insert(
- model_id.clone(),
- client
- .check(HealthCheckRequest {
- service: model_id.clone(),
- })
- .await
- .into(),
- );
- }
- Ok(results)
- }
+ client: NlpServiceClient<LoadBalancedChannel>,
+ health_client: HealthClient<LoadBalancedChannel>,
}
#[cfg_attr(test, faux::methods)]
impl NlpClient {
- pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
- let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await;
- let health_clients = create_grpc_clients(default_port, config, HealthClient::new).await;
+ pub async fn new(config: &ServiceConfig) -> Self {
+ let client = create_grpc_client(DEFAULT_PORT, config, NlpServiceClient::new).await;
+ let health_client = create_grpc_client(DEFAULT_PORT, config, HealthClient::new).await;
Self {
- clients,
- health_clients,
+ client,
+ health_client,
}
}
- fn client(&self, _model_id: &str) -> Result<NlpServiceClient<LoadBalancedChannel>, Error> {
- // NOTE: We currently forward requests to common router, so we use a single client.
- let model_id = COMMON_ROUTER_KEY;
- Ok(self
- .clients
- .get(model_id)
- .ok_or_else(|| Error::ModelNotFound {
- model_id: model_id.to_string(),
- })?
- .clone())
- }
-
+ #[instrument(skip_all, fields(model_id))]
pub async fn tokenization_task_predict(
&self,
model_id: &str,
request: TokenizationTaskRequest,
headers: HeaderMap,
) -> Result<TokenizationResults, Error> {
- let request = request_with_model_id(request, model_id, headers);
- Ok(self
- .client(model_id)?
- .tokenization_task_predict(request)
- .await?
- .into_inner())
+ let mut client = self.client.clone();
+ let request = request_with_headers(request, model_id, headers);
+ info!(?request, "sending request to NLP gRPC service");
+ let response = client.tokenization_task_predict(request).await?;
+ trace_context_from_grpc_response(&response);
+ Ok(response.into_inner())
}
+ #[instrument(skip_all, fields(model_id))]
pub async fn token_classification_task_predict(
&self,
model_id: &str,
request: TokenClassificationTaskRequest,
headers: HeaderMap,
) -> Result<TokenClassificationResults, Error> {
- let request = request_with_model_id(request, model_id, headers);
- Ok(self
- .client(model_id)?
- .token_classification_task_predict(request)
- .await?
- .into_inner())
+ let mut client = self.client.clone();
+ let request = request_with_headers(request, model_id, headers);
+ info!(?request, "sending request to NLP gRPC service");
+ let response = client.token_classification_task_predict(request).await?;
+ trace_context_from_grpc_response(&response);
+ Ok(response.into_inner())
}
+ #[instrument(skip_all, fields(model_id))]
pub async fn text_generation_task_predict(
&self,
model_id: &str,
request: TextGenerationTaskRequest,
headers: HeaderMap,
) -> Result<GeneratedTextResult, Error> {
- let request = request_with_model_id(request, model_id, headers);
- Ok(self
- .client(model_id)?
- .text_generation_task_predict(request)
- .await?
- .into_inner())
+ let mut client = self.client.clone();
+ let request = request_with_headers(request, model_id, headers);
+ info!(?request, "sending request to NLP gRPC service");
+ let response = client.text_generation_task_predict(request).await?;
+ trace_context_from_grpc_response(&response);
+ Ok(response.into_inner())
}
+ #[instrument(skip_all, fields(model_id))]
pub async fn server_streaming_text_generation_task_predict(
&self,
model_id: &str,
request: ServerStreamingTextGenerationTaskRequest,
headers: HeaderMap,
) -> Result<BoxStream<Result<GeneratedTextStreamResult, Error>>, Error> {
- let request = request_with_model_id(request, model_id, headers);
- let response_stream = self
- .client(model_id)?
+ let mut client = self.client.clone();
+ let request = request_with_headers(request, model_id, headers);
+ info!(?request, "sending stream request to NLP gRPC service");
+ let response = client
.server_streaming_text_generation_task_predict(request)
- .await?
- .into_inner()
- .map_err(Into::into)
- .boxed();
+ .await?;
+ trace_context_from_grpc_response(&response);
+ let response_stream = response.into_inner().map_err(Into::into).boxed();
Ok(response_stream)
}
}
-fn request_with_model_id<T>(request: T, model_id: &str, headers: HeaderMap) -> Request<T> {
- let metadata = MetadataMap::from_headers(headers);
- let mut request = Request::from_parts(metadata, Extensions::new(), request);
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for NlpClient {
+ fn name(&self) -> &str {
+ "nlp"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ let mut client = self.health_client.clone();
+ let response = client
+ .check(HealthCheckRequest { service: "".into() })
+ .await;
+ let code = match response {
+ Ok(_) => Code::Ok,
+ Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => {
+ Code::Ok
+ }
+ Err(status) => status.code(),
+ };
+ let status = if matches!(code, Code::Ok) {
+ HealthStatus::Healthy
+ } else {
+ HealthStatus::Unhealthy
+ };
+ HealthCheckResult {
+ status,
+ code: grpc_to_http_code(code),
+ reason: None,
+ }
+ }
+}
+
+/// Turns an NLP client gRPC request body of type `T` and headers into a `tonic::Request`.
+/// Also injects provided `model_id` and `traceparent` from current context into headers.
+fn request_with_headers<T>(request: T, model_id: &str, headers: HeaderMap) -> Request<T> {
+ let mut request = grpc_request_with_headers(request, headers);
request
.metadata_mut()
.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
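request_with_headers above is what stamps the mm-model-id metadata onto each outbound gRPC request. A minimal sketch, callable only from inside this module since the helper is private; the unit payload stands in for a real request body:

    use axum::http::HeaderMap;

    fn model_id_metadata_sketch() {
        let request = request_with_headers((), "my-model", HeaderMap::new());
        let model_id = request.metadata().get("mm-model-id").unwrap();
        assert_eq!(model_id.to_str().unwrap(), "my-model");
    }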
diff --git a/src/clients/nlp_http.rs b/src/clients/nlp_http.rs
new file mode 100644
index 00000000..7b4c3645
--- /dev/null
+++ b/src/clients/nlp_http.rs
@@ -0,0 +1,184 @@
+use async_trait::async_trait;
+use futures::StreamExt;
+use hyper::{HeaderMap, StatusCode};
+use reqwest_eventsource::RequestBuilderExt;
+use tracing::{info, instrument};
+
+use super::{create_http_client, BoxStream, Client, Error, HttpClient};
+use crate::{
+ config::ServiceConfig,
+ health::HealthCheckResult,
+ pb::{
+ caikit::runtime::nlp::{
+ ServerStreamingTextGenerationTaskRequest, TextGenerationTaskRequest,
+ TokenClassificationTaskRequest, TokenizationTaskRequest,
+ },
+ caikit_data_model::nlp::{
+ GeneratedTextResult, GeneratedTextStreamResult, TokenClassificationResults,
+ TokenizationResults,
+ },
+ },
+ tracing_utils::with_traceparent_header,
+};
+
+const DEFAULT_PORT: u16 = 8085;
+const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct NlpClientHttp {
+ client: HttpClient,
+ health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl NlpClientHttp {
+ pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+ let client = create_http_client(DEFAULT_PORT, config).await;
+ let health_client = if let Some(health_config) = health_config {
+ Some(create_http_client(DEFAULT_PORT, health_config).await)
+ } else {
+ None
+ };
+ Self {
+ client,
+ health_client,
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id))]
+ pub async fn tokenization_task_predict(
+ &self,
+ model_id: &str,
+ request: TokenizationTaskRequest,
+ headers: HeaderMap,
+ ) -> Result<TokenizationResults, Error> {
+ let url = self.client.base_url().join("/api/v1/task/tokenization").unwrap();
+ let mut headers = with_traceparent_header(headers);
+ headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+ info!(?request, "sending request to NLP http service");
+ let response = self
+ .client
+ .post(url)
+ .headers(headers)
+ .json(&request)
+ .send()
+ .await?;
+ match response.status() {
+ StatusCode::OK => Ok(response.json::<TokenizationResults>().await?),
+ _ => {
+ let code = response.status();
+ // No error body type is defined for this client, so fall back to
+ // the raw response body for the error message.
+ let message = response
+ .text()
+ .await
+ .unwrap_or_else(|_| "unknown error occurred".into());
+ Err(Error::Http { code, message })
+ }
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id))]
+ pub async fn token_classification_task_predict(
+ &self,
+ model_id: &str,
+ request: TokenClassificationTaskRequest,
+ headers: HeaderMap,
+ ) -> Result<TokenClassificationResults, Error> {
+ let url = self.client.base_url().join("/api/v1/task/token-classification").unwrap();
+ let mut headers = with_traceparent_header(headers);
+ headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+ info!(?request, "sending request to NLP http service");
+ let response = self
+ .client
+ .post(url)
+ .headers(headers)
+ .json(&request)
+ .send()
+ .await?;
+ match response.status() {
+ StatusCode::OK => Ok(response.json::<TokenClassificationResults>().await?),
+ _ => {
+ let code = response.status();
+ let message = response
+ .text()
+ .await
+ .unwrap_or_else(|_| "unknown error occurred".into());
+ Err(Error::Http { code, message })
+ }
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id))]
+ pub async fn text_generation_task_predict(
+ &self,
+ model_id: &str,
+ request: TextGenerationTaskRequest,
+ headers: HeaderMap,
+ ) -> Result<GeneratedTextResult, Error> {
+ let url = self.client.base_url().join("/api/v1/task/text-generation").unwrap();
+ let mut headers = with_traceparent_header(headers);
+ headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+ info!(?request, "sending request to NLP http service");
+ let response = self
+ .client
+ .post(url)
+ .headers(headers)
+ .json(&request)
+ .send()
+ .await?;
+ match response.status() {
+ StatusCode::OK => Ok(response.json::<GeneratedTextResult>().await?),
+ _ => {
+ let code = response.status();
+ let message = response
+ .text()
+ .await
+ .unwrap_or_else(|_| "unknown error occurred".into());
+ Err(Error::Http { code, message })
+ }
+ }
+ }
+
+ #[instrument(skip_all, fields(model_id))]
+ pub async fn server_streaming_text_generation_task_predict(
+ &self,
+ model_id: &str,
+ request: ServerStreamingTextGenerationTaskRequest,
+ headers: HeaderMap,
+ ) -> Result<BoxStream<Result<GeneratedTextStreamResult, Error>>, Error> {
+ let url = self.client.base_url().join("/api/v1/task/streaming-text-generation").unwrap();
+ let mut headers = with_traceparent_header(headers);
+ headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+ info!(?request, "sending stream request to NLP http service");
+ // Stream server-sent events, decoding each data payload into a
+ // GeneratedTextStreamResult.
+ let event_stream = self
+ .client
+ .post(url)
+ .headers(headers)
+ .json(&request)
+ .eventsource()
+ .unwrap();
+ let response_stream = event_stream
+ .filter_map(|result| async move {
+ match result {
+ Ok(reqwest_eventsource::Event::Open) => None,
+ Ok(reqwest_eventsource::Event::Message(message)) => Some(
+ serde_json::from_str::<GeneratedTextStreamResult>(&message.data).map_err(
+ |error| Error::Http {
+ code: StatusCode::INTERNAL_SERVER_ERROR,
+ message: error.to_string(),
+ },
+ ),
+ ),
+ Err(reqwest_eventsource::Error::StreamEnded) => None,
+ Err(error) => Some(Err(Error::Http {
+ code: StatusCode::INTERNAL_SERVER_ERROR,
+ message: error.to_string(),
+ })),
+ }
+ })
+ .boxed();
+ Ok(response_stream)
+ }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for NlpClientHttp {
+ fn name(&self) -> &str {
+ "nlp_http"
+ }
+ async fn health(&self) -> HealthCheckResult {
+ if let Some(health_client) = &self.health_client {
+ health_client.health().await
+ } else {
+ self.client.health().await
+ }
+ }
+}
\ No newline at end of file
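The generation client's NlpHttp arm builds these prost-generated request types with struct-update syntax. A minimal sketch, assuming the pb types derive Default as prost-generated messages normally do:

    use crate::pb::caikit::runtime::nlp::TextGenerationTaskRequest;

    fn sparse_request_sketch() {
        // Only `text` is set; every sampling parameter keeps its default (None).
        let request = TextGenerationTaskRequest {
            text: "tell me a story".into(),
            ..Default::default()
        };
        assert!(request.max_new_tokens.is_none());
    }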
diff --git a/src/clients/openai.rs b/src/clients/openai.rs
new file mode 100644
index 00000000..a626d111
--- /dev/null
+++ b/src/clients/openai.rs
@@ -0,0 +1,643 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use std::{collections::HashMap, convert::Infallible};
+
+use async_trait::async_trait;
+use axum::response::sse;
+use futures::StreamExt;
+use hyper::{HeaderMap, StatusCode};
+use reqwest_eventsource::{Event, RequestBuilderExt};
+use serde::{Deserialize, Serialize};
+use tokio::sync::mpsc;
+use tracing::{info, instrument};
+
+use super::{create_http_client, Client, Error, HttpClient};
+use crate::{
+ config::ServiceConfig, health::HealthCheckResult, tracing_utils::with_traceparent_header,
+};
+
+const DEFAULT_PORT: u16 = 8080;
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct OpenAiClient {
+ client: HttpClient,
+ health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl OpenAiClient {
+ pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+ let client = create_http_client(DEFAULT_PORT, config).await;
+ let health_client = if let Some(health_config) = health_config {
+ Some(create_http_client(DEFAULT_PORT, health_config).await)
+ } else {
+ None
+ };
+ Self {
+ client,
+ health_client,
+ }
+ }
+
+ #[instrument(skip_all, fields(request.model))]
+ pub async fn chat_completions(
+ &self,
+ request: ChatCompletionsRequest,
+ headers: HeaderMap,
+ ) -> Result<ChatCompletionsResponse, Error> {
+ let url = self.client.base_url().join("/v1/chat/completions").unwrap();
+ let headers = with_traceparent_header(headers);
+ let stream = request.stream.unwrap_or_default();
+ info!(?url, ?headers, ?request, "sending client request");
+ if stream {
+ let (tx, rx) = mpsc::channel(32);
+ let mut event_stream = self
+ .client
+ .post(url)
+ .headers(headers)
+ .json(&request)
+ .eventsource()
+ .unwrap();
+ // Spawn task to forward events to receiver
+ tokio::spawn(async move {
+ while let Some(result) = event_stream.next().await {
+ match result {
+ Ok(event) => {
+ if let Event::Message(message) = event {
+ let event = sse::Event::default().data(message.data);
+ let _ = tx.send(Ok(event)).await;
+ }
+ }
+ Err(reqwest_eventsource::Error::StreamEnded) => break,
+ Err(error) => {
+ // We received an error from the event stream, send an error event
+ let event =
+ sse::Event::default().event("error").data(error.to_string());
+ let _ = tx.send(Ok(event)).await;
+ }
+ }
+ }
+ });
+ Ok(ChatCompletionsResponse::Streaming(rx))
+ } else {
+ let response = self
+ .client
+ .post(url)
+ .headers(headers)
+ .json(&request)
+ .send()
+ .await?;
+ match response.status() {
+ StatusCode::OK => Ok(response.json::<ChatCompletion>().await?.into()),
+ _ => {
+ let code = response.status();
+ let message = if let Ok(response) = response.json::<OpenAiError>().await {
+ response.message
+ } else {
+ "unknown error occurred".into()
+ };
+ Err(Error::Http { code, message })
+ }
+ }
+ }
+ }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for OpenAiClient {
+ fn name(&self) -> &str {
+ "openai"
+ }
+
+ async fn health(&self) -> HealthCheckResult {
+ if let Some(health_client) = &self.health_client {
+ health_client.health().await
+ } else {
+ self.client.health().await
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum ChatCompletionsResponse {
+ Unary(ChatCompletion),
+ Streaming(mpsc::Receiver<Result<sse::Event, Infallible>>),
+}
+
+impl From<ChatCompletion> for ChatCompletionsResponse {
+ fn from(value: ChatCompletion) -> Self {
+ Self::Unary(value)
+ }
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionsRequest {
+ /// A list of messages comprising the conversation so far.
+ pub messages: Vec<Message>,
+ /// ID of the model to use.
+ pub model: String,
+ /// Whether or not to store the output of this chat completion request.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub store: Option<bool>,
+ /// Developer-defined tags and values.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub metadata: Option<HashMap<String, String>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub frequency_penalty: Option<f32>,
+ /// Modify the likelihood of specified tokens appearing in the completion.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub logit_bias: Option<HashMap<String, f32>>,
+ /// Whether to return log probabilities of the output tokens or not.
+ /// If true, returns the log probabilities of each output token returned in the content of message.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub logprobs: Option<bool>,
+ /// An integer between 0 and 20 specifying the number of most likely tokens to return
+ /// at each token position, each with an associated log probability.
+ /// logprobs must be set to true if this parameter is used.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub top_logprobs: Option<u32>,
+ /// The maximum number of tokens that can be generated in the chat completion. (DEPRECATED)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub max_tokens: Option<u32>,
+ /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub max_completion_tokens: Option<u32>,
+ /// How many chat completion choices to generate for each input message.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub n: Option<u32>,
+ /// Positive values penalize new tokens based on whether they appear in the text so far,
+ /// increasing the model's likelihood to talk about new topics.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub presence_penalty: Option<f32>,
+ /// An object specifying the format that the model must output.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub response_format: Option<ResponseFormat>,
+ /// If specified, our system will make a best effort to sample deterministically,
+ /// such that repeated requests with the same seed and parameters should return the same result.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub seed: Option<u64>,
+ /// Specifies the latency tier to use for processing the request.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub service_tier: Option<String>,
+ /// Up to 4 sequences where the API will stop generating further tokens.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub stop: Option<StopTokens>,
+ /// If set, partial message deltas will be sent, like in ChatGPT.
+ /// Tokens will be sent as data-only server-sent events as they become available,
+ /// with the stream terminated by a data: [DONE] message.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub stream: Option<bool>,
+ /// Options for streaming response. Only set this when you set stream: true.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub stream_options: Option<StreamOptions>,
+ /// What sampling temperature to use, between 0 and 2.
+ /// Higher values like 0.8 will make the output more random,
+ /// while lower values like 0.2 will make it more focused and deterministic.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub temperature: Option<f32>,
+ /// An alternative to sampling with temperature, called nucleus sampling,
+ /// where the model considers the results of the tokens with top_p probability mass.
+ /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub top_p: Option<f32>,
+ /// A list of tools the model may call.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tools: Option<Vec<Tool>>,
+ /// Controls which (if any) tool is called by the model.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option<ToolChoice>,
+ /// Whether to enable parallel function calling during tool use.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub parallel_tool_calls: Option<bool>,
+ /// A unique identifier representing your end-user.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub user: Option<String>,
+
+ // Additional vllm params
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub best_of: Option<usize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub use_beam_search: Option<bool>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub top_k: Option<isize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub min_p: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub repetition_penalty: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub length_penalty: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub early_stopping: Option<bool>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub ignore_eos: Option<bool>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub min_tokens: Option<u32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub stop_token_ids: Option<Vec<usize>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub skip_special_tokens: Option<bool>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub spaces_between_special_tokens: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponseFormat {
+ /// The type of response format being defined.
+ #[serde(rename = "type")]
+ pub r#type: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub json_schema: Option<JsonSchema>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JsonSchema {
+ /// The name of the response format.
+ pub name: String,
+ /// A description of what the response format is for, used by the model to determine how to respond in the format.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub description: Option<String>,
+ /// The schema for the response format, described as a JSON Schema object.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub schema: Option<JsonSchemaObject>,
+ /// Whether to enable strict schema adherence when generating the output.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub strict: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Tool {
+ /// The type of the tool.
+ #[serde(rename = "type")]
+ pub r#type: String,
+ pub function: ToolFunction,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolFunction {
+ /// The name of the function to be called.
+ pub name: String,
+ /// A description of what the function does, used by the model to choose when and how to call the function.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub description: Option<String>,
+ /// The parameters the functions accepts, described as a JSON Schema object.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub parameters: Option<JsonSchemaObject>,
+ /// Whether to enable strict schema adherence when generating the function call.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub strict: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+ /// `none` means the model will not call any tool and instead generates a message.
+ /// `auto` means the model can pick between generating a message or calling one or more tools.
+ /// `required` means the model must call one or more tools.
+ String(String),
+ /// Specifies a tool the model should use. Use to force the model to call a specific function.
+ Object(ToolChoiceObject),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolChoiceObject {
+ /// The type of the tool.
+ #[serde(rename = "type")]
+ pub r#type: String,
+ pub function: Function,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StreamOptions {
+ /// If set, an additional chunk will be streamed before the data: [DONE] message.
+ /// The usage field on this chunk shows the token usage statistics for the entire
+ /// request, and the choices field will always be an empty array. All other chunks
+ /// will also include a usage field, but with a null value.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub include_usage: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JsonSchemaObject {
+ pub id: String,
+ pub schema: String,
+ pub title: String,
+ pub description: Option<String>,
+ #[serde(rename = "type")]
+ pub r#type: String,
+ pub properties: Option<HashMap<String, serde_json::Value>>,
+ pub required: Option<Vec<String>>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct Message {
+ /// The role of the messages author.
+ pub role: String,
+ /// The contents of the message.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub content: Option<Content>,
+ /// An optional name for the participant.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub name: Option<String>,
+ /// The refusal message by the assistant. (assistant message only)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub refusal: Option<String>,
+ /// The tool calls generated by the model, such as function calls. (assistant message only)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_calls: Option<Vec<ToolCall>>,
+ /// Tool call that this message is responding to. (tool message only)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_call_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum Content {
+ /// The text contents of the message.
+ Text(String),
+ /// Array of content parts.
+ Array(Vec<ContentPart>),
+}
+
+impl From<String> for Content {
+ fn from(value: String) -> Self {
+ Content::Text(value)
+ }
+}
+
+impl From<&str> for Content {
+ fn from(value: &str) -> Self {
+ Content::Text(value.to_string())
+ }
+}
+
+impl From<Vec<ContentPart>> for Content {
+ fn from(value: Vec<ContentPart>) -> Self {
+ Content::Array(value)
+ }
+}
+
+impl From<String> for ContentPart {
+ fn from(value: String) -> Self {
+ ContentPart {
+ r#type: ContentType::Text,
+ text: Some(value),
+ image_url: None,
+ refusal: None,
+ }
+ }
+}
+
+impl From<Vec<String>> for Content {
+ fn from(value: Vec<String>) -> Self {
+ Content::Array(value.into_iter().map(|v| v.into()).collect())
+ }
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub enum ContentType {
+ #[serde(rename = "text")]
+ #[default]
+ Text,
+ #[serde(rename = "image_url")]
+ ImageUrl,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct ContentPart {
+ /// The type of the content part.
+ #[serde(rename = "type")]
+ pub r#type: ContentType,
+ /// Text content
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub text: Option<String>,
+ /// Image content
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub image_url: Option<ImageUrl>,
+ /// The refusal message generated by the model. (assistant message only)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub refusal: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ImageUrl {
+ /// Either a URL of the image or the base64 encoded image data.
+ pub url: String,
+ /// Specifies the detail level of the image.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub detail: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCall {
+ /// The ID of the tool call.
+ pub id: String,
+ /// The type of the tool.
+ #[serde(rename = "type")]
+ pub r#type: String,
+ /// The function that the model called.
+ pub function: Function,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Function {
+ /// The name of the function to call.
+ pub name: String,
+ /// The arguments to call the function with, as generated by the model in JSON format.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub arguments: Option<HashMap<String, serde_json::Value>>,
+}
+
+/// Represents a chat completion response returned by model, based on the provided input.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletion {
+ /// A unique identifier for the chat completion.
+ pub id: String,
+ /// A list of chat completion choices. Can be more than one if n is greater than 1.
+ pub choices: Vec<ChatCompletionChoice>,
+ /// The Unix timestamp (in seconds) of when the chat completion was created.
+ pub created: i64,
+ /// The model used for the chat completion.
+ pub model: String,
+ /// The service tier used for processing the request.
+ /// This field is only included if the `service_tier` parameter is specified in the request.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub service_tier: Option<String>,
+ /// This fingerprint represents the backend configuration that the model runs with.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub system_fingerprint: Option<String>,
+ /// The object type, which is always `chat.completion`.
+ pub object: String,
+ /// Usage statistics for the completion request.
+ pub usage: Usage,
+}
+
+/// A chat completion choice.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionChoice {
+ /// The reason the model stopped generating tokens.
+ pub finish_reason: String,
+ /// The index of the choice in the list of choices.
+ pub index: usize,
+ /// A chat completion message generated by the model.
+ pub message: ChatCompletionMessage,
+ /// Log probability information for the choice.
+ pub logprobs: Option<ChatCompletionLogprobs>,
+}
+
+/// A chat completion message generated by the model.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionMessage {
+ /// The contents of the message.
+ pub content: Option<String>,
+ /// The refusal message generated by the model.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub refusal: Option<String>,
+ #[serde(default)]
+ pub tool_calls: Vec<ToolCall>,
+ /// The role of the author of this message.
+ pub role: String,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ChatCompletionLogprobs {
+ /// A list of message content tokens with log probability information.
+ pub content: Option<Vec<ChatCompletionLogprob>>,
+ /// A list of message refusal tokens with log probability information.
+ pub refusal: Option<Vec<ChatCompletionLogprob>>,
+}
+
+/// Log probability information for a choice.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionLogprob {
+ /// The token.
+ pub token: String,
+ /// The log probability of this token.
+ pub logprob: f32,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub bytes: Option<Vec<u8>>,
+ /// List of the most likely tokens and their log probability, at this token position.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub top_logprobs: Option<Vec<ChatCompletionTopLogprob>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionTopLogprob {
+ /// The token.
+ pub token: String,
+ /// The log probability of this token.
+ pub logprob: f32,
+}
+
+/// Represents a streamed chunk of a chat completion response returned by model, based on the provided input.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionChunk {
+ /// A unique identifier for the chat completion. Each chunk has the same ID.
+ pub id: String,
+ /// A list of chat completion choices.
+ pub choices: Vec<ChatCompletionChunkChoice>,
+ /// The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp.
+ pub created: i64,
+ /// The model to generate the completion.
+ pub model: String,
+ /// The service tier used for processing the request.
+ /// This field is only included if the service_tier parameter is specified in the request.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub service_tier: Option<String>,
+ /// This fingerprint represents the backend configuration that the model runs with.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub system_fingerprint: Option<String>,
+ /// The object type, which is always `chat.completion.chunk`.
+ pub object: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub usage: Option<Usage>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionChunkChoice {
+ /// A chat completion delta generated by streamed model responses.
+ pub delta: ChatCompletionDelta,
+ /// Log probability information for the choice.
+ pub logprobs: Option<ChatCompletionLogprobs>,
+ /// The reason the model stopped generating tokens.
+ pub finish_reason: Option<String>,
+ /// The index of the choice in the list of choices.
+ pub index: u32,
+}
+
+/// A chat completion delta generated by streamed model responses.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionDelta {
+ /// The contents of the message.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub content: Option<String>,
+ /// The refusal message generated by the model.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub refusal: Option<String>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub tool_calls: Vec<ToolCall>,
+ /// The role of the author of this message.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub role: Option<String>,
+}
+
+/// Usage statistics for a completion.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Usage {
+ /// Number of tokens in the generated completion.
+ pub completion_tokens: u32,
+ /// Number of tokens in the prompt.
+ pub prompt_tokens: u32,
+ /// Total number of tokens used in the request (prompt + completion).
+ pub total_tokens: u32,
+ /// Breakdown of tokens used in a completion.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub completion_token_details: Option<CompletionTokenDetails>,
+ /// Breakdown of tokens used in the prompt.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub prompt_token_details: Option<PromptTokenDetails>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CompletionTokenDetails {
+ pub audio_tokens: u32,
+ pub reasoning_tokens: u32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PromptTokenDetails {
+ pub audio_tokens: u32,
+ pub cached_tokens: u32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum StopTokens {
+ Array(Vec<String>),
+ String(String),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OpenAiError {
+ pub object: Option<String>,
+ pub message: String,
+ #[serde(rename = "type")]
+ pub r#type: Option<String>,
+ pub param: Option<String>,