diff --git a/Cargo.lock b/Cargo.lock index 72ede252..4bbb6417 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -555,6 +555,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.1.1" @@ -604,11 +615,17 @@ dependencies = [ "faux", "futures", "ginepro", + "http-serde", "hyper", "hyper-util", "mio", + "opentelemetry", + "opentelemetry-http", + "opentelemetry-otlp", + "opentelemetry_sdk", "prost", "reqwest", + "reqwest-eventsource", "rustls", "rustls-pemfile", "rustls-webpki", @@ -621,8 +638,10 @@ dependencies = [ "tokio-stream", "tonic", "tonic-build", + "tower-http", "tower-service", "tracing", + "tracing-opentelemetry", "tracing-subscriber", "tracing-test", "url", @@ -736,6 +755,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -902,6 +927,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-serde" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd" +dependencies = [ + "http 1.1.0", + "serde", +] + [[package]] name = "httparse" version = "1.9.5" @@ -1372,6 +1407,85 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + +[[package]] +name = "opentelemetry-http" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad31e9de44ee3538fb9d64fe3376c1362f406162434609e79aea2a41a0af78ab" +dependencies = [ + "async-trait", + "bytes", + "http 1.1.0", + "opentelemetry", + "reqwest", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b925a602ffb916fb7421276b86756027b37ee708f9dce2dbdcc51739f07e727" +dependencies = [ + "async-trait", + "futures-core", + "http 1.1.0", + "opentelemetry", + "opentelemetry-http", + "opentelemetry-proto", + "opentelemetry_sdk", + "prost", + "thiserror", + "tokio", + "tonic", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9" +dependencies = [ + "opentelemetry", + "opentelemetry_sdk", + "prost", + "tonic", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry", + "percent-encoding", + "rand", + 
"serde_json", + "thiserror", + "tokio", + "tokio-stream", +] + [[package]] name = "overload" version = "0.1.1" @@ -1750,15 +1864,33 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "windows-registry", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "reserve-port" version = "2.0.1" @@ -2380,6 +2512,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags", + "bytes", + "http 1.1.0", + "http-body", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2436,6 +2585,24 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-opentelemetry" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry", + "opentelemetry_sdk", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + [[package]] name = "tracing-serde" version = "0.1.3" @@ -2703,6 +2870,19 @@ version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +[[package]] +name = "wasm-streams" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.70" @@ -2713,6 +2893,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.6" diff --git a/Cargo.toml b/Cargo.toml index b4ae3e1e..310c2c62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,16 +15,24 @@ path = "src/main.rs" [dependencies] anyhow = "1.0.86" +async-trait = "0.1.81" +async-stream = "0.3.5" axum = { version = "0.7.5", features = ["json"] } axum-extra = "0.9.3" clap = { version = "4.5.15", features = ["derive", "env"] } futures = "0.3.30" ginepro = "0.8.1" +http-serde = "2.1.1" hyper = { version = "1.4.1", features = ["http1", "http2", "server"] } hyper-util = { version = "0.1.7", features = ["server-auto", "server-graceful", "tokio"] } mio = "1.0.2" +opentelemetry = { version = "0.24.0", features = ["trace", "metrics"] } +opentelemetry-http = { version = "0.13.0", features = ["reqwest"] } +opentelemetry-otlp = { version = "0.17.0", features = ["http-proto"] } +opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio", "metrics"] } prost = "0.13.1" reqwest 
= { version = "0.12.5", features = ["blocking", "rustls-tls", "json"] } +reqwest-eventsource = "0.6.0" rustls = {version = "0.23.12", default-features = false, features = ["std"]} rustls-pemfile = "2.1.3" rustls-webpki = "0.102.6" @@ -36,13 +44,13 @@ tokio = { version = "1.39.2", features = ["rt", "rt-multi-thread", "parking_lot" tokio-rustls = { version = "0.26.0" } tokio-stream = { version = "0.1.15", features = ["sync"] } tonic = { version = "0.12.1", features = ["tls", "tls-roots", "tls-webpki-roots"] } +tower-http = { version = "0.5.2", features = ["trace"] } tower-service = "0.3" tracing = "0.1.40" +tracing-opentelemetry = "0.25.0" tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } url = "2.5.2" uuid = { version = "1.10.0", features = ["v4", "fast-rng"] } -async-trait = "0.1.81" -async-stream = "0.3.5" [build-dependencies] tonic-build = "0.12.1" diff --git a/Dockerfile b/Dockerfile index 6caee35c..d1913a4c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,7 @@ ARG UBI_MINIMAL_BASE_IMAGE=registry.access.redhat.com/ubi9/ubi-minimal ARG UBI_BASE_IMAGE_TAG=latest ARG PROTOC_VERSION=26.0 +ARG CONFIG_FILE=config/config.yaml ## Rust builder ################################################################ # Specific debian version so that compatible glibc version is used @@ -23,10 +24,10 @@ RUN rustup component add rustfmt ## Orchestrator builder ######################################################### FROM rust-builder as fms-guardrails-orchestr8-builder -COPY build.rs *.toml LICENSE /app -COPY config/ /app/config -COPY protos/ /app/protos -COPY src/ /app/src +COPY build.rs *.toml LICENSE /app/ +COPY ${CONFIG_FILE} /app/config/config.yaml +COPY protos/ /app/protos/ +COPY src/ /app/src/ WORKDIR /app @@ -50,7 +51,7 @@ RUN cargo fmt --check FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} as fms-guardrails-orchestr8-release COPY --from=fms-guardrails-orchestr8-builder /app/bin/ /app/bin/ -COPY config /app/config +COPY ${CONFIG_FILE} /app/config/config.yaml RUN microdnf install -y --disableplugin=subscription-manager shadow-utils compat-openssl11 && \ microdnf clean all --disableplugin=subscription-manager diff --git a/config/config.yaml b/config/config.yaml index 2fe6a657..ef5be2c4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -11,6 +11,12 @@ generation: service: hostname: localhost port: 8033 +# Generation server used for chat endpoints +# chat_generation: +# service: +# hostname: localhost +# port: 8080 +# # health_service: # Any chunker servers that will be used by any detectors chunkers: # Chunker ID/name @@ -26,11 +32,16 @@ chunkers: detectors: # Detector ID/name to be used in user requests hap-en: + # Detector type (text_contents, text_generation, text_chat, text_context_doc) + type: text_contents service: - hostname: https://localhost/api/v1/text/contents # Full url / endpoint currently expected + hostname: localhost port: 8080 # TLS ID/name, optional (detailed in `tls` section) tls: detector + health_service: + hostname: localhost + port: 8081 # Chunker ID/name from `chunkers` section if applicable chunker_id: en_regex # Default score threshold for a detector. If a user @@ -53,8 +64,7 @@ tls: detector_bundle_no_ca: cert_path: /path/to/client-bundle.pem insecure: true - # Following section can be used to configure the allowed headers that orchestrator will pass to # NLP provider and detectors. Note that, this section takes header keys, not values. 
# passthrough_headers: -# - header-key \ No newline at end of file +# - header-key diff --git a/config/test.config.yaml b/config/test.config.yaml index f32a0a26..cdde1905 100644 --- a/config/test.config.yaml +++ b/config/test.config.yaml @@ -3,6 +3,10 @@ generation: service: hostname: localhost port: 443 +# chat_generation: +# service: +# hostname: localhost +# port: 8080 chunkers: test_chunker: type: sentence @@ -11,8 +15,9 @@ chunkers: port: 8085 detectors: test_detector: + type: text_contents service: - hostname: https://localhost/api/v1/text/contents + hostname: localhost port: 8000 chunker_id: test_chunker - default_threshold: 0.5 \ No newline at end of file + default_threshold: 0.5 diff --git a/docs/api/openapi_detector_api.yaml b/docs/api/openapi_detector_api.yaml index 9bd96f0d..fb64ea23 100644 --- a/docs/api/openapi_detector_api.yaml +++ b/docs/api/openapi_detector_api.yaml @@ -5,9 +5,14 @@ info: name: Apache 2.0 url: https://www.apache.org/licenses/LICENSE-2.0.html version: 0.0.1 +tags: + - name: Text + description: Detections on text paths: /api/v1/text/contents: post: + tags: + - Text summary: Text Content Analysis Unary Handler description: >- Detectors that work on content text, be it user prompt or generated @@ -67,6 +72,8 @@ paths: $ref: '#/components/schemas/Error' /api/v1/text/generation: post: + tags: + - Text summary: Generation Analysis Unary Handler description: >- Detectors that run on prompt and text generation output.
@@ -115,6 +122,8 @@ paths: $ref: '#/components/schemas/Error' /api/v1/text/chat: post: + tags: + - Text summary: Chat Analysis Unary Handler description: >- Detectors that analyze chat messages and provide detections
@@ -162,6 +171,8 @@ paths: $ref: '#/components/schemas/Error' /api/v1/text/context/doc: post: + tags: + - Text summary: Context Analysis Unary Handler description: >- Detectors that work on a context created by document(s).
diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml index d7196d3e..cd27db1d 100644 --- a/docs/api/orchestrator_openapi_0_1_0.yaml +++ b/docs/api/orchestrator_openapi_0_1_0.yaml @@ -3,12 +3,12 @@ info: title: FMS Orchestrator API version: 0.1.0 tags: - - name: Task - Text Generation, with detection - description: Detections on text generation model input and/or output - - name: Task - Detection - description: Standalone detections - - name: Task - Chat Completions, with detection - description: Detections on list of messages comprising a conversation and/or completions from a model + - name: Task - Text Generation, with detection + description: Detections on text generation model input and/or output + - name: Task - Detection + description: Standalone detections + - name: Task - Chat Completions, with detection + description: Detections on list of messages comprising a conversation and/or completions from a model paths: /health: get: @@ -17,7 +17,7 @@ paths: summary: Performs quick liveliness check of the orchestrator service operationId: health responses: - '200': + "200": description: Healthy content: application/json: @@ -42,18 +42,18 @@ paths: type: boolean default: false responses: - '200': + "200": description: Orchestrator successfully probed client health statuses content: application/json: schema: - $ref: '#/components/schemas/HealthProbeResponse' - '503': + $ref: "#/components/schemas/InfoResponse" + "503": description: Orchestrator failed to probe client health statuses content: application/json: schema: - $ref: '#/components/schemas/HealthProbeResponse' + $ref: "#/components/schemas/InfoResponse" /api/v1/task/classification-with-text-generation: post: tags: @@ -65,27 +65,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GuardrailsHttpRequest' + $ref: "#/components/schemas/GuardrailsHttpRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/ClassifiedGeneratedTextResult' - '404': + $ref: "#/components/schemas/ClassifiedGeneratedTextResult" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v1/task/server-streaming-classification-with-text-generation: post: tags: @@ -97,27 +97,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GuardrailsHttpRequest' + $ref: "#/components/schemas/GuardrailsHttpRequest" required: true responses: - '200': + "200": description: Successful Response content: - application/json: + text/event-stream: schema: - $ref: '#/components/schemas/ClassifiedGeneratedTextStreamResult' - '404': + $ref: "#/components/schemas/ClassifiedGeneratedTextStreamResult" + "404": description: Resource Not Found content: - application/json: + text/event-stream: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: - application/json: + text/event-stream: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/generation-detection: post: tags: @@ -129,27 +129,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GenerationDetectionRequest' + $ref: 
"#/components/schemas/GenerationDetectionRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/GenerationDetectionResponse' - '404': + $ref: "#/components/schemas/GenerationDetectionResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/detection/content: post: @@ -162,27 +162,77 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/DetectionContentRequest' + $ref: "#/components/schemas/DetectionContentRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/DetectionContentResponse' - '404': + $ref: "#/components/schemas/DetectionContentResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" + /api/v2/text/detection/stream-content: + post: + tags: + - Task - Detection + summary: Detection task on input content stream + operationId: >- + api_v2_detection_text_content_bidi_stream_handler + requestBody: + content: + application/x-ndjson: + schema: + oneOf: + - $ref: "#/components/schemas/DetectionContentRequest" + - $ref: "#/components/schemas/DetectionContentStreamEvent" + # In OpenAPI 3.0, examples cannot be present in schemas, + # whereas object level examples are present in 3.1 + examples: + first_event: + summary: First text event with detectors + value: + detectors: + hap-v1-model-en: {} + content: "my text here" + text: + summary: Regular text event + value: + content: "my text here" + required: true + responses: + "200": + description: Successful Response + content: + text/event-stream: + schema: + # NOTE: This endpoint, like the + # `server-streaming-classification-with-text-generation` + # endpoint will produce streamed events + $ref: "#/components/schemas/DetectionContentStreamResponse" + "404": + description: Resource Not Found + content: + text/event-stream: + schema: + $ref: "#/components/schemas/Error" + "422": + description: Validation Error + content: + text/event-stream: + schema: + $ref: "#/components/schemas/Error" /api/v2/text/detection/chat: post: tags: @@ -194,27 +244,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/DetectionChatRequest' + $ref: "#/components/schemas/DetectionChatRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/DetectionChatResponse' - '404': + $ref: "#/components/schemas/DetectionChatResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/detection/context: post: @@ -227,27 +277,27 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/DetectionContextDocsRequest' + $ref: 
"#/components/schemas/DetectionContextDocsRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/DetectionContextDocsResponse' - '404': + $ref: "#/components/schemas/DetectionContextDocsResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /api/v2/text/detection/generated: post: tags: @@ -259,126 +309,109 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GeneratedTextDetectionRequest' + $ref: "#/components/schemas/GeneratedTextDetectionRequest" required: true responses: - '200': + "200": description: Successful Response content: application/json: schema: - $ref: '#/components/schemas/GeneratedTextDetectionResponse' - '404': + $ref: "#/components/schemas/GeneratedTextDetectionResponse" + "404": description: Resource Not Found content: application/json: schema: - $ref: '#/components/schemas/Error' - '422': + $ref: "#/components/schemas/Error" + "422": description: Validation Error content: application/json: schema: - $ref: '#/components/schemas/Error' - + $ref: "#/components/schemas/Error" + /api/v2/chat/completions-detection: post: - tags: - - Task - Chat Completions, with detection - operationId: >- - api_v2_chat_completions_detection_handler - summary: Creates a model response with detections for the given chat conversation. - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/GuardrailsCreateChatCompletionRequest" - responses: - '200': - description: Successful Response - content: - application/json: - schema: - $ref: "#/components/schemas/GuardrailsCreateChatCompletionResponse" - '404': - description: Resource Not Found - content: - application/json: - schema: - $ref: '#/components/schemas/Error' - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/Error' + tags: + - Task - Chat Completions, with detection + operationId: >- + api_v2_chat_completions_detection_handler + summary: Creates a model response with detections for the given chat conversation. 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/GuardrailsCreateChatCompletionRequest" + responses: + "200": + description: Successful Response + content: + application/json: + schema: + $ref: "#/components/schemas/GuardrailsCreateChatCompletionResponse" + "404": + description: Resource Not Found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "422": + description: Validation Error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" components: schemas: HealthStatus: - type: string - enum: - - HEALTHY - - UNHEALTHY - - UNKNOWN - title: Health Status + type: string + enum: + - HEALTHY + - UNHEALTHY + - UNKNOWN + title: Health Status HealthCheckResult: - oneOf: - - properties: - health_status: - $ref: '#/components/schemas/HealthStatus' - response_code: - type: string - title: Response Code - example: "HTTP 200 OK" - reason: - type: string - title: Reason - example: "Service not found" - required: - - health_status - - response_code - - additionalProperties: - $ref: '#/components/schemas/HealthStatus' + properties: + status: + $ref: "#/components/schemas/HealthStatus" + code: + type: string + title: Response Code + example: 200 + reason: + type: string + title: Reason + example: "Not Found" + required: + - status type: object - title: Health Check Response - HealthProbeResponse: + title: Health Check Result + InfoResponse: properties: services: type: object title: Health status for each client service - properties: - generation: - type: object - title: Generation Services - items: - $ref: '#/components/schemas/HealthCheckResult' - detectors: - type: object - title: Detector Services - items: - $ref: '#/components/schemas/HealthCheckResult' - chunkers: - type: object - title: Chunker Services - items: - $ref: '#/components/schemas/HealthCheckResult' + items: + $ref: "#/components/schemas/HealthCheckResult" required: - services type: object - title: Health Probe Response + title: Info Response DetectionContentRequest: properties: detectors: type: object title: Detectors default: {} - example: + example: hap-v1-model-en: {} content: type: string title: Content + example: "my text here" required: ["detectors", "content"] additionalProperties: false type: object @@ -388,7 +421,7 @@ components: detections: type: array items: - $ref: '#/components/schemas/DetectionContentResponseObject' + $ref: "#/components/schemas/DetectionContentResponseObject" additionalProperties: false required: ["detections"] type: object @@ -415,15 +448,38 @@ components: title: Score title: Content Detection Response Object example: - - { - "start": 0, - "end": 20, - "text": "string", - "detection_type": "HAP", - "detection": "has_HAP", - "detector_id": "hap-v1-model-en", - "score": 0.999 - } + start: 0 + end: 20 + text: "string" + detection_type: "HAP" + detection: "has_HAP" + detector_id: "hap-v1-model-en" + score: 0.999 + DetectionContentStreamEvent: + properties: + content: + type: string + title: Content + example: "my text here" + required: ["content"] + type: object + description: Individual stream event + title: Content Detection Stream Event + DetectionContentStreamResponse: + properties: + detections: + type: array + items: + $ref: "#/components/schemas/DetectionContentResponseObject" + processed_index: + anyOf: + - type: integer + title: Processed Index + start_index: + type: integer + title: Start Index + type: object + title: Content Detection Stream Response DetectionChatRequest: properties: @@ 
-431,7 +487,7 @@ components: type: object title: Detectors default: {} - example: + example: chat-v1-model-en: {} messages: title: Chat Messages @@ -465,14 +521,14 @@ components: title: Detections on entire history of chat messages title: Chat Detection Response required: ["detections"] - + DetectionContextDocsRequest: properties: detectors: type: object title: Detectors default: {} - example: + example: context-v1-model-en: {} content: type: string @@ -500,7 +556,7 @@ components: detections: type: array items: - $ref: '#/components/schemas/DetectionContextDocsResponseObject' + $ref: "#/components/schemas/DetectionContextDocsResponseObject" required: ["detections"] title: Context Docs Detection Response DetectionContextDocsResponseObject: @@ -516,9 +572,9 @@ components: title: Score evidence: anyOf: - - items: - $ref: '#/components/schemas/EvidenceObj' - type: array + - items: + $ref: "#/components/schemas/EvidenceObj" + type: array title: Context Docs Detection Response Object GenerationDetectionRequest: @@ -533,11 +589,11 @@ components: type: object title: Detectors default: {} - example: + example: generation-detection-v1-model-en: {} text_gen_parameters: allOf: - - $ref: '#/components/schemas/GuardrailsTextGenerationParameters' + - $ref: "#/components/schemas/GuardrailsTextGenerationParameters" type: object required: ["model_id", "prompt", "detectors"] title: Generation-Detection Request @@ -566,7 +622,7 @@ components: title: Input token Count title: Generation Detection Response required: ["generated_text", "detections"] - + GeneratedTextDetectionRequest: properties: prompt: @@ -579,7 +635,7 @@ components: type: object title: Detectors default: {} - example: + example: generated-detection-v1-model-en: {} type: object required: ["generated_text", "prompt", "detectors"] @@ -589,7 +645,7 @@ components: detections: type: array items: - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject' + $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: ["detections"] title: Generated Text Detection Response GeneratedTextDetectionResponseObject: @@ -614,10 +670,10 @@ components: title: Generated Text token_classification_results: anyOf: - - $ref: '#/components/schemas/TextGenTokenClassificationResults' + - $ref: "#/components/schemas/TextGenTokenClassificationResults" finish_reason: anyOf: - - $ref: '#/components/schemas/FinishReason' + - $ref: "#/components/schemas/FinishReason" generated_token_count: anyOf: - type: integer @@ -633,19 +689,19 @@ components: warnings: anyOf: - items: - $ref: '#/components/schemas/InputWarning' + $ref: "#/components/schemas/InputWarning" type: array title: Warnings tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Tokens input_tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Input Tokens additionalProperties: false @@ -660,10 +716,10 @@ components: title: Generated Text token_classification_results: anyOf: - - $ref: '#/components/schemas/TextGenTokenClassificationResults' + - $ref: "#/components/schemas/TextGenTokenClassificationResults" finish_reason: anyOf: - - $ref: '#/components/schemas/FinishReason' + - $ref: "#/components/schemas/FinishReason" generated_token_count: anyOf: - type: integer @@ -679,19 +735,19 @@ components: warnings: anyOf: - items: - $ref: '#/components/schemas/InputWarning' + $ref: "#/components/schemas/InputWarning" type: array title: 
Warnings tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Tokens input_tokens: anyOf: - items: - $ref: '#/components/schemas/GeneratedToken' + $ref: "#/components/schemas/GeneratedToken" type: array title: Input Tokens processed_index: @@ -706,18 +762,18 @@ components: type: object title: Classified Generated Text Stream Result TextGenTokenClassificationResults: - # By default open-api spec consider all fields as optional + # By default open-api spec consider all fields as optional properties: input: anyOf: - items: - $ref: '#/components/schemas/TokenClassificationResult' + $ref: "#/components/schemas/TokenClassificationResult" type: array title: Input output: anyOf: - items: - $ref: '#/components/schemas/TokenClassificationResult' + $ref: "#/components/schemas/TokenClassificationResult" type: array title: Output additionalProperties: false @@ -765,7 +821,7 @@ components: default: {} required: - detectors - + GuardrailsCreateChatCompletionResponse: title: Guardrails Chat Completion Response description: Guardrails chat completion response (adds detections on OpenAI chat completion) @@ -778,7 +834,7 @@ components: warnings: type: array items: - $ref: '#/components/schemas/Warning' + $ref: "#/components/schemas/Warning" required: - detections @@ -800,27 +856,27 @@ components: output: pii-v1: {} conversation-detector: {} - + ChatCompletionsDetections: title: Chat Completions Detections properties: input: type: array items: - $ref: '#/components/schemas/MessageDetections' + $ref: "#/components/schemas/MessageDetections" title: Detections on input to chat completions default: {} output: type: array items: - $ref: '#/components/schemas/ChoiceDetections' + $ref: "#/components/schemas/ChoiceDetections" title: Detections on output of chat completions default: {} default: {} example: input: - message_index: 0 - results: + results: - { "start": 0, "end": 80, @@ -828,7 +884,7 @@ components: "detection_type": "HAP", "detection": "has_HAP", "detector_id": "hap-v1-model-en", # Future addition - "score": 0.999 + "score": 0.999, } output: - choice_index: 0 @@ -841,15 +897,15 @@ components: "detection_type": "HAP", "detection": "has_HAP", "detector_id": "hap-v1-model-en", # Future addition - "score": 0.999 + "score": 0.999, } - { "detection_type": "string", "detection": "string", "detector_id": "relevance-v1-en", # Future addition - "score": 0 + "score": 0, } - + MessageDetections: title: Message Detections properties: @@ -863,9 +919,9 @@ components: type: array items: anyOf: - - $ref: '#/components/schemas/DetectionContentResponseObject' - - $ref: '#/components/schemas/DetectionContextDocsResponseObject' - - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject' + - $ref: "#/components/schemas/DetectionContentResponseObject" + - $ref: "#/components/schemas/DetectionContextDocsResponseObject" + - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: - message_index ChoiceDetections: @@ -881,9 +937,9 @@ components: type: array items: anyOf: - - $ref: '#/components/schemas/DetectionContentResponseObject' - - $ref: '#/components/schemas/DetectionContextDocsResponseObject' - - $ref: '#/components/schemas/GeneratedTextDetectionResponseObject' + - $ref: "#/components/schemas/DetectionContentResponseObject" + - $ref: "#/components/schemas/DetectionContextDocsResponseObject" + - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: - choice_index @@ -948,7 +1004,7 @@ 
components:
       evidence:
         anyOf:
           - items:
-              $ref: '#/components/schemas/Evidence'
+              $ref: "#/components/schemas/Evidence"
             type: array
     type: object
     required:
@@ -1010,7 +1066,7 @@ components:
         title: Inputs
       guardrail_config:
         allOf:
-          - $ref: '#/components/schemas/GuardrailsConfig'
+          - $ref: "#/components/schemas/GuardrailsConfig"
         default:
           input:
             masks: []
@@ -1019,7 +1075,7 @@ components:
             models: {}
       text_gen_parameters:
         allOf:
-          - $ref: '#/components/schemas/GuardrailsTextGenerationParameters'
+          - $ref: "#/components/schemas/GuardrailsTextGenerationParameters"
       type: object
       required:
         - model_id
@@ -1059,7 +1115,7 @@ components:
         title: Max Time
       exponential_decay_length_penalty:
         allOf:
-          - $ref: '#/components/schemas/ExponentialDecayLengthPenalty'
+          - $ref: "#/components/schemas/ExponentialDecayLengthPenalty"
       stop_sequences:
         items:
           type: string
@@ -1094,7 +1150,7 @@ components:
       properties:
         id:
           allOf:
-            - $ref: '#/components/schemas/InputWarningReason'
+            - $ref: "#/components/schemas/InputWarningReason"
         message:
           type: string
           title: Message
diff --git a/docs/architecture/adrs/005-chat-completion-support.md b/docs/architecture/adrs/005-chat-completion-support.md
index d04e92da..6fb5156a 100644
--- a/docs/architecture/adrs/005-chat-completion-support.md
+++ b/docs/architecture/adrs/005-chat-completion-support.md
@@ -73,4 +73,4 @@ This means that the orchestrator will have to be able to track chunking and dete
 ## Status
 
-Proposed
+Accepted
diff --git a/docs/architecture/adrs/006-detector-type.md b/docs/architecture/adrs/006-detector-type.md
new file mode 100644
index 00000000..99c4e077
--- /dev/null
+++ b/docs/architecture/adrs/006-detector-type.md
@@ -0,0 +1,45 @@
+# ADR 006: Detector Type
+
+This ADR documents the decision to add the `type` parameter for detectors in the orchestrator config.
+
+## Motivation
+
+The guardrails orchestrator interfaces with different types of detectors.
+Detectors of a given type are compatible with only a subset of orchestrator endpoints.
+To reduce the chance of misconfiguration, we need a way to map detectors so that they are used only with compatible endpoints. This additionally provides a way to refer to a particular detector type within the code without inspecting its `hostname` (URL), which can be error-prone. A good example is validating whether a given detector will work with a given orchestrator endpoint.
+
+## Decision
+
+We decided to add the `type` parameter to the detectors configuration.
+Possible values are `text_contents`, `text_chat`, `text_generation`, and `text_context_doc`.
+Below is an example of a detector configuration.
+
+```yaml
+detectors:
+  my_detector:
+    type: text_contents # Options: text_contents, text_chat, text_context_doc, text_generation
+    service:
+      hostname: my-detector.com
+      port: 8080
+      tls: my_certs
+    chunker_id: my_chunker
+    default_threshold: 0.5
+```
+
+## Consequences
+
+1. Reduced misconfiguration risk.
+2. Future logic can be implemented for detectors of a particular type.
+3. `hostname` no longer needs the full URL, but only the actual hostname.
+4. If `tls` is provided, the `https` protocol is used; otherwise, `http`.
+5. Not including `type` results in a configuration validation error on orchestrator startup.
+6. Detector endpoints are automatically configured based on `type` as follows:
+   * `text_contents` -> `/api/v1/text/contents`
+   * `text_chat` -> `/api/v1/text/chat`
+   * `text_context_doc` -> `/api/v1/text/context/doc`
+   * `text_generation` -> `/api/v1/text/generation`
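+
+As an illustration of consequence 6, the type-to-endpoint mapping can be thought of as a simple
+match over a detector-type enum. This is a minimal sketch, not the orchestrator's actual
+implementation; the `DetectorType` enum and `endpoint` method names are hypothetical.
+
+```rust
+/// Hypothetical enum mirroring the allowed `type` config values.
+pub enum DetectorType {
+    TextContents,
+    TextChat,
+    TextContextDoc,
+    TextGeneration,
+}
+
+impl DetectorType {
+    /// Returns the endpoint path implied by the detector type.
+    pub fn endpoint(&self) -> &'static str {
+        match self {
+            DetectorType::TextContents => "/api/v1/text/contents",
+            DetectorType::TextChat => "/api/v1/text/chat",
+            DetectorType::TextContextDoc => "/api/v1/text/context/doc",
+            DetectorType::TextGeneration => "/api/v1/text/generation",
+        }
+    }
+}
+```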
+
+## Status
+
+Accepted
\ No newline at end of file
diff --git a/docs/architecture/adrs/007-orchestrator-telemetry.md b/docs/architecture/adrs/007-orchestrator-telemetry.md
new file mode 100644
index 00000000..19faffac
--- /dev/null
+++ b/docs/architecture/adrs/007-orchestrator-telemetry.md
@@ -0,0 +1,96 @@
+# ADR 007: Orchestrator Telemetry
+
+The guardrails orchestrator uses [OpenTelemetry](https://opentelemetry.io/) to collect and export telemetry data (traces, metrics, and logs).
+
+The orchestrator needs to collect telemetry data for monitoring and observability. It also needs to be able to trace
+spans for incoming requests and across client requests to configured detectors, chunkers, and generation services, and
+to aggregate detailed traces, metrics, and logs that can be monitored from a variety of observability backends.
+
+## Decision
+
+### OpenTelemetry and `tracing`
+
+The orchestrator and client services will make use of the OpenTelemetry SDK and the [OpenTelemetry Protocol (OTLP)](https://opentelemetry.io/docs/specs/otel/protocol/)
+for consolidating and collecting telemetry data across services. The orchestrator will be responsible for collecting
+telemetry data throughout the lifetime of a request using the `tracing` crate, which is the de facto choice for logging
+and tracing for OpenTelemetry in Rust, and exporting it through the OTLP exporter if configured. The OTLP exporter will
+send telemetry data to a gRPC or HTTP endpoint that can be configured to point to a running OpenTelemetry (OTEL) collector.
+Similarly, detectors should also be able to collect and export telemetry through OTLP to the same OTEL collector.
+From the OTEL collector, the telemetry data can then be exported to multiple backends. The OTEL collector and
+any observability backends can all be configured alongside the orchestrator and detectors in a deployment.
+
+### Instrumentation
+An incoming request to the orchestrator will initialize a new trace; a trace ID and a request should therefore be in
+one-to-one correspondence. All important functions throughout the control flow of handling a request in the orchestrator
+will be instrumented with the `#[tracing::instrument]` attribute macro above the function definition. This will create
+and enter a span for each function call and add it to the trace of the request. Here, important functions refer to any
+functions performing business logic that may incur significant latency, including all the handler functions
+for incoming and outgoing requests. It is up to the developer's discretion to determine which functions are
+"significant" enough to warrant a new span in the trace, but a new tracing span can always be added trivially by
+applying the instrument macro.
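+
+As a minimal sketch (the function name and fields here are hypothetical, not the orchestrator's
+actual handler signatures), instrumenting a function looks like the following:
+
+```rust
+use tracing::instrument;
+
+/// Calling this function creates and enters a span named `handle_detection`,
+/// which is added to the trace of the current request.
+#[instrument(skip_all, fields(detector_id = %detector_id))]
+async fn handle_detection(detector_id: String, content: String) {
+    tracing::debug!("running detection");
+    // ... invoke the detector client here ...
+}
+```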
Both the OTLP +metrics and the `spanmetrics` metrics can be exported to configured metrics backends like Prometheus or Grafana. +The orchestrator will handle a variety of useful metrics such as counters and histograms for received/handled +successful/failed requests, request and stream durations, and server/client errors. Traces and metrics will also relate +incoming orchestrator requests to respective client requests/responses, and collect more business specific metrics +e.g. regarding the outcome of running detection or generation. + +### Configuration +The orchestrator will expose CLI args/env variables for configuring the OTLP exporter: +- `OTEL_EXPORTER_OTLP_PROTOCOL=grpc|http` to set the protocol for all the OTLP endpoints + - `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` and `OTEL_EXPORTER_OTLP_METRICS_PROTOCOL` to set/override the protocol for + traces or metrics. +- `--otlp-endpoint, OTEL_EXPORTER_OTLP_ENDPOINT` to set the OTLP endpoint + - defaults: gRPC `localhost:4317` and HTTP `localhost:4318` + - `--otlp-traces-endpoint, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT` and `--otlp-metrics-endpoint, + OTEL_EXPORTER_OTLP_METRICS_ENDPOINT` to set/override the endpoint for traces or metrics + - default to `localhost:4317` for gRPC for all data types, and `localhost:4318/v1/traces`, or `metrics`, for HTTP +- `--otlp-export, OTLP_EXPORT=traces,metrics` to specify a list of which data types to export to the OTLP exporters, separated by a + comma. The possible values are traces, metrics, or both. The OTLP standard specifies three data types (`traces`, + `metrics`, `logs`) but since we use the recommended `tracing` crate for logging, we can export logs as traces and + not use the separate (more experimental) logging export pipeline. +- `RUST_LOG=error|warn|info|debug|trace` to set the Rust log level. +- `--log-format, LOG_FORMAT=full|compact|json|pretty` to set the logging format for logs printed to stdout. All logs collected as + traces by OTLP will just be structured traces, this argument is specifically for stdout. Default is `full`. +- `--quiet, -q` to silence logging to stdout. If `OTLP_EXPORT=traces` is still provided, all logs can still be viewed + as traces from an observability backend. + +### Cross-service tracing +The orchestrator and client services will be able to consolidate telemetry and share observability through a common +configuration and backends. This will be made possible through the use of the OTLP standard as well as through the +propagation of the trace context through requests across services using the standardized `traceparent` header. The +orchestrator will be expected to initialize a new trace for an incoming request and pass `traceparent` headers +corresponding to this trace to any requests outgoing to clients, and similarly, the orchestrator will expect the client +to provide a `traceparent` header in the response. The orchestrator will not propagate the `traceparent` to outgoing +responses back to the end user (or expect `traceparent` in incoming requests) for security reasons. + +## Status + +Accepted + +## Consequences + +- The orchestrator and client services have a common standard to conform to for telemetry, allowing for traceability + across different services. There does not exist any other attempts at telemetry standardization that is as widely + accepted as OpenTelemetry, or have the same level of support from existing observability and monitoring services. 
+
+### Cross-service tracing
+The orchestrator and client services will be able to consolidate telemetry and share observability through a common
+configuration and backends. This will be made possible through the use of the OTLP standard as well as through the
+propagation of the trace context through requests across services using the standardized `traceparent` header. The
+orchestrator will be expected to initialize a new trace for an incoming request and pass `traceparent` headers
+corresponding to this trace to any requests outgoing to clients, and similarly, the orchestrator will expect the client
+to provide a `traceparent` header in the response. The orchestrator will not propagate the `traceparent` to outgoing
+responses back to the end user (or expect `traceparent` in incoming requests) for security reasons.
+
+## Status
+
+Accepted
+
+## Consequences
+
+- The orchestrator and client services have a common standard to conform to for telemetry, allowing for traceability
+  across different services. No other telemetry standardization effort is as widely accepted as OpenTelemetry or has
+  the same level of support from existing observability and monitoring services.
+- The deployment of the orchestrator must be configured with telemetry service(s) listening for telemetry exported on
+  the specified endpoint(s). An [OTEL collector](https://opentelemetry.io/docs/collector/) service can be used to
+  collect and propagate the telemetry data, or the export endpoint(s) can be listened to directly by any backend that
+  supports OTLP (e.g. Jaeger).
+- The orchestrator and client services do not need to be concerned with specific observability backends; the OTEL
+  collector and OTLP standard can be used to export telemetry data to a variety of backends including Jaeger,
+  Prometheus, Grafana, and Instana, as well as to OpenShift natively through the OpenTelemetryCollector CRD.
+- Using the `tracing` crate in Rust for logging will treat logs as traces, allowing the orchestrator to export logs
+  through the trace provider (with the OTLP exporter), simplifying the implementation and avoiding use of the logging
+  provider, which is still considered experimental in many contexts (it exists for compatibility with non-`tracing`
+  logging libraries).
+- For stdout, the new `--log-format` and `--quiet` arguments add more configurability to format or silence logging.
+- The integration of the OpenTelemetry API/SDK into the stack is not trivial, and the OpenTelemetry crates will add
+  compile time to the orchestrator.
+- The OpenTelemetry API/SDK and OTLP standard are new and still evolving; the orchestrator will need to keep up with
+  changes in the OpenTelemetry ecosystem, and there could be occasional breaking changes that will need addressing.
\ No newline at end of file
diff --git a/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md b/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md
new file mode 100644
index 00000000..418c25ac
--- /dev/null
+++ b/docs/architecture/adrs/008-streaming-orchestrator-endpoints.md
@@ -0,0 +1,42 @@
+# ADR 008: Streaming orchestrator endpoints
+
+This ADR documents the patterns and behavior expected for streaming orchestrator endpoints.
+
+The orchestrator API can be found [at these github pages](https://foundation-model-stack.github.io/fms-guardrails-orchestrator/).
+
+## Motivation
+
+In [ADR 004](./004-orchestrator-input-only-api-design.md), the design of "input only" detection endpoints was detailed. Currently, those endpoints support only the "unary" case, where the entire input text is available upfront. For flexibility (for example, when text is streamed from a generative model that is available but not callable through the endpoints with generation), users may still want to call detections on streamed input text.
+
+The orchestrator will then need to support "bidirectional streaming" endpoints, where text (whether tokens, words, or sentences) is streamed in, detectors are invoked (and call their respective chunkers, using bidirectional streaming), and text processed with detectors, including potential detections, is streamed back to the user.
+
+## Decisions
+
+### Server streaming or endpoint output streaming
+"Server streaming" endpoints already existed prior to the writing of this ADR. Streaming response aggregation behavior is documented in [ADR 002](./002-streaming-response-aggregation.md). Data will continue to be streamed back with `data` events, with errors included as `event: error` per the [SSE event format](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format).
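+
+For illustration, a response stream could carry events shaped like the following (the payloads
+are hypothetical and abbreviated):
+
+```
+data: {"detections": [], "start_index": 0, "processed_index": 42}
+
+event: error
+data: {"code": 422, "details": "..."}
+```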
+
+Parameters in each response event, such as `start_index` and `processed_index`, will indicate to the user how much of the input stream has been processed for detections, as there might not necessarily be results (like positive `detections`) for certain portions of the input stream. The `start_index` and `processed_index` will be relative to the entire stream.
+
+### Client streaming or endpoint input streaming
+- Any information needed for an entire request, like the `detectors` that any detection endpoints will work on, will be expected to be present in the first event of a stream, as sketched below. The structure of the stream events expected will be documented for each endpoint.
+  - An alternate consideration was using query or path parameters for information needed for an entire request, like `detectors`, but this would be complicated by the nesting that `detectors` currently require, with a mapping of each detector to dictionary parameters.
+  - Another alternate consideration was expecting multipart requests, one part with information for the entire request like `detectors` and another part with individual stream events. However, here the content type accepted by the request would have to change.
+- Stream closing will be the expected indication that stream events have ended.
+  - An alternate consideration was an explicit "end of stream" request message for each endpoint, for the user to indicate that the connection should be closed. For example, for the OpenAI chat completions API, this looks like a `[DONE]` event. The downside here is that this particular event's contents would have to be identified and processed differently from other events.
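+
+As a sketch of the first-event pattern above, an input stream to the stream-content detection endpoint could look like the following newline-delimited JSON, where only the first event carries `detectors` (the detector name and text are illustrative):
+
+```
+{"detectors": {"hap-v1-model-en": {}}, "content": "first chunk of text"}
+{"content": "second chunk of text"}
+{"content": "third chunk of text"}
+```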
+
+### Separate streaming detection endpoints
+
+To be clear to users, we will start with endpoints that indicate `stream` in the endpoint name. We want to avoid adding `stream` parameters in the request body, since this would increase the maintenance of parameters on each request event in the streaming case. Additionally, as detailed earlier, stream endpoint responses will tend to have additional or potentially different fields than their unary counterparts. This point can be revisited based on sufficient user feedback.
+
+NOTE: This ADR will not prescribe implementation details, but while the underlying implementation _could_ use [websockets](https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API), we are explicitly not following the patterns of some websocket APIs that require connecting and disconnecting.
+
+## Consequences
+
+- Stream detection endpoints will be separate from the current "unary" ones that take entire inputs and return one response. Users then must change endpoints for this different use case.
+- The orchestrator can support input or client streaming in a consistent manner. This will enable orchestrator users that may want to stream input content from other sources, like their own generative model.
+- Users have to be aware that for input streaming, the first event may need to contain more information necessary for the endpoint. Thus the event message structure may not be exactly the same across events in the stream.
+
+## Status
+
+Proposed
diff --git a/src/args.rs b/src/args.rs
new file mode 100644
index 00000000..b5b38f03
--- /dev/null
+++ b/src/args.rs
@@ -0,0 +1,214 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use std::{fmt::Display, path::PathBuf};
+
+use clap::Parser;
+use tracing::{error, warn};
+
+#[derive(Parser, Debug, Clone)]
+#[clap(author, version, about, long_about = None)]
+pub struct Args {
+    #[clap(default_value = "8033", long, env)]
+    pub http_port: u16,
+    #[clap(default_value = "8034", long, env)]
+    pub health_http_port: u16,
+    #[clap(
+        default_value = "config/config.yaml",
+        long,
+        env = "ORCHESTRATOR_CONFIG"
+    )]
+    pub config_path: PathBuf,
+    #[clap(long, env)]
+    pub tls_cert_path: Option<PathBuf>,
+    #[clap(long, env)]
+    pub tls_key_path: Option<PathBuf>,
+    #[clap(long, env)]
+    pub tls_client_ca_cert_path: Option<PathBuf>,
+    #[clap(default_value = "false", long, env)]
+    pub start_up_health_check: bool,
+    #[clap(long, env, value_delimiter = ',')]
+    pub otlp_export: Vec<OtlpExport>,
+    #[clap(default_value_t = LogFormat::default(), long, env)]
+    pub log_format: LogFormat,
+    #[clap(default_value_t = false, long, short, env)]
+    pub quiet: bool,
+    #[clap(default_value = "fms_guardrails_orchestr8", long, env)]
+    pub otlp_service_name: String,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")]
+    pub otlp_endpoint: Option<String>,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")]
+    pub otlp_traces_endpoint: Option<String>,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT")]
+    pub otlp_metrics_endpoint: Option<String>,
+    #[clap(
+        default_value_t = OtlpProtocol::Grpc,
+        long,
+        env = "OTEL_EXPORTER_OTLP_PROTOCOL"
+    )]
+    pub otlp_protocol: OtlpProtocol,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL")]
+    pub otlp_traces_protocol: Option<OtlpProtocol>,
+    #[clap(long, env = "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL")]
+    pub otlp_metrics_protocol: Option<OtlpProtocol>,
+    // TODO: Add timeout and header OTLP variables
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum OtlpExport {
+    Traces,
+    Metrics,
+}
+
+impl Display for OtlpExport {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            OtlpExport::Traces => write!(f, "traces"),
+            OtlpExport::Metrics => write!(f, "metrics"),
+        }
+    }
+}
+
+impl From<String> for OtlpExport {
+    fn from(s: String) -> Self {
+        match s.to_lowercase().as_str() {
+            "traces" => OtlpExport::Traces,
+            "metrics" => OtlpExport::Metrics,
+            _ => panic!(
+                "Invalid OTLP export type {}, orchestrator only supports exporting traces and metrics via OTLP",
+                s
+            ),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+pub enum OtlpProtocol {
+    #[default]
+    Grpc,
+    Http,
+}
+
+impl Display for OtlpProtocol {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            OtlpProtocol::Grpc => write!(f, "grpc"),
+            OtlpProtocol::Http => write!(f, "http"),
+        }
+    }
+}
+
+impl From<String> for OtlpProtocol {
+    fn from(s: String) -> Self {
+        match s.to_lowercase().as_str() {
+            "grpc" => OtlpProtocol::Grpc,
+            "http" => OtlpProtocol::Http,
+            _ => {
+                error!(
+                    "Invalid OTLP protocol {}, defaulting to {}",
+                    s,
+                    OtlpProtocol::default()
+                );
+                OtlpProtocol::default()
+            }
+        }
+    }
+}
+
+impl OtlpProtocol {
+    pub fn default_endpoint(&self) -> &str {
+        match self {
+            OtlpProtocol::Grpc => "http://localhost:4317",
+            OtlpProtocol::Http => "http://localhost:4318",
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, Default, PartialEq)]
+pub enum LogFormat {
+    #[default]
+    Full,
+    Compact,
+    Pretty,
+    JSON,
+}
+
+impl Display for LogFormat {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            LogFormat::Full => write!(f, "full"),
+            LogFormat::Compact => write!(f, "compact"),
+            LogFormat::Pretty => write!(f, "pretty"),
+            LogFormat::JSON => write!(f, "json"),
+        }
+    }
+}
+
+impl From<String> for LogFormat {
+    fn from(s: String) -> Self {
+        match s.to_lowercase().as_str() {
+            "full" => LogFormat::Full,
+            "compact" => LogFormat::Compact,
+            "pretty" => LogFormat::Pretty,
+            "json" => LogFormat::JSON,
+            _ => {
+                warn!(
+                    "Invalid trace format {}, defaulting to {}",
+                    s,
+                    LogFormat::default()
+                );
+                LogFormat::default()
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TracingConfig {
+    pub service_name: String,
+    pub traces: Option<(OtlpProtocol, String)>,
+    pub metrics: Option<(OtlpProtocol, String)>,
+    pub log_format: LogFormat,
+    pub quiet: bool,
+}
+
+impl From<Args> for TracingConfig {
+    fn from(args: Args) -> Self {
+        let otlp_protocol = args.otlp_protocol;
+        let otlp_endpoint = args
+            .otlp_endpoint
+            .unwrap_or(otlp_protocol.default_endpoint().to_string());
+        let otlp_traces_endpoint = args.otlp_traces_endpoint.unwrap_or(otlp_endpoint.clone());
+        let otlp_metrics_endpoint = args.otlp_metrics_endpoint.unwrap_or(otlp_endpoint.clone());
+        let otlp_traces_protocol = args.otlp_traces_protocol.unwrap_or(otlp_protocol);
+        let otlp_metrics_protocol = args.otlp_metrics_protocol.unwrap_or(otlp_protocol);
+
+        TracingConfig {
+            service_name: args.otlp_service_name,
+            traces: match args.otlp_export.contains(&OtlpExport::Traces) {
+                true => Some((otlp_traces_protocol, otlp_traces_endpoint)),
+                false => None,
+            },
+            metrics: match args.otlp_export.contains(&OtlpExport::Metrics) {
+                true => Some((otlp_metrics_protocol, otlp_metrics_endpoint)),
+                false => None,
+            },
+            log_format: args.log_format,
+            quiet: args.quiet,
+        }
+    }
+}
diff --git a/src/clients.rs b/src/clients.rs
index 4e40ae34..b7964934 100644
--- a/src/clients.rs
+++ b/src/clients.rs
@@ -16,29 +16,38 @@
 */
 #![allow(dead_code)]
 
-// Import error for adding `source` trait
-use std::{collections::HashMap, error::Error as _, fmt::Display, pin::Pin, time::Duration};
+use std::{
+    any::TypeId,
+    collections::{hash_map, HashMap},
+    pin::Pin,
+    time::Duration,
+};
 
-use futures::{future::join_all, Stream};
+use async_trait::async_trait;
+use axum::http::{Extensions, HeaderMap};
+use futures::Stream;
 use ginepro::LoadBalancedChannel;
-use reqwest::{Response, StatusCode};
 use tokio::{fs::File, io::AsyncReadExt};
-use tracing::error;
+use tonic::{metadata::MetadataMap, Request};
+use tracing::{debug, instrument};
 use url::Url;
 
 use crate::{
     config::{ServiceConfig, Tls},
-    health::{HealthCheck, HealthCheckResult, HealthStatus, OptionalHealthCheckResponseBody},
+    health::HealthCheckResult,
+    tracing_utils::with_traceparent_header,
 };
 
+pub mod errors;
+pub use errors::Error;
+
+pub mod http;
+pub use http::HttpClient;
+
 pub mod chunker;
-pub use chunker::ChunkerClient;
 
 pub mod detector;
-pub use detector::DetectorClient;
+pub use detector::TextContentsDetectorClient;
-
-pub mod generation;
-pub use generation::GenerationClient;
 
 pub mod tgis;
 pub use tgis::TgisClient;
@@ -46,358 +55,319 @@ pub use tgis::TgisClient;
 pub mod nlp;
 pub use nlp::NlpClient;
 
-pub const DEFAULT_TGIS_PORT: u16 = 8033;
-pub const DEFAULT_CAIKIT_NLP_PORT: u16 = 8085;
-pub const DEFAULT_CHUNKER_PORT: u16 = 8085;
-pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
-pub const COMMON_ROUTER_KEY: &str = "common-router";
+pub mod generation;
+pub use generation::GenerationClient;
+
+pub mod nlp_http;
+pub use nlp_http::NlpClientHttp;
+
+pub mod openai;
+
 const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
 const DEFAULT_REQUEST_TIMEOUT_SEC: u64 = 600;
 
 pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
 
-/// Client errors.
-#[derive(Debug, Clone, PartialEq, thiserror::Error)]
-pub enum Error {
-    #[error("{}", .message)]
-    Grpc { code: StatusCode, message: String },
-    #[error("{}", .message)]
-    Http { code: StatusCode, message: String },
-    #[error("model not found: {model_id}")]
-    ModelNotFound { model_id: String },
+mod private {
+    pub struct Seal;
 }
 
-impl Error {
-    /// Returns status code.
-    pub fn status_code(&self) -> StatusCode {
-        match self {
-            // Return equivalent http status code for grpc status code
-            Error::Grpc { code, .. } => *code,
-            // Return http status code for error responses
-            // and 500 for other errors
-            Error::Http { code, .. } => *code,
-            // Return 404 for model not found
-            Error::ModelNotFound { .. } => StatusCode::NOT_FOUND,
-        }
+#[async_trait]
+pub trait Client: Send + Sync + 'static {
+    /// Returns the name of the client type.
+    fn name(&self) -> &str;
+
+    /// Returns the `TypeId` of the client type. Sealed to prevent overrides.
+    fn type_id(&self, _: private::Seal) -> TypeId {
+        TypeId::of::<Self>()
     }
+
+    /// Performs a client health check.
+    async fn health(&self) -> HealthCheckResult;
 }
 
-impl From<reqwest::Error> for Error {
-    fn from(value: reqwest::Error) -> Self {
-        // Log lower level source of error.
-        // Examples:
-        // 1. client error (Connect) // Cases like connection error, wrong port etc.
-        // 2. client error (SendRequest) // Cases like cert issues
-        error!(
-            "http request failed. Source: {}",
-            value.source().unwrap().to_string()
-        );
-        // Return http status code for error responses
-        // and 500 for other errors
-        let code = match value.status() {
-            Some(code) => code,
-            None => StatusCode::INTERNAL_SERVER_ERROR,
-        };
-        Self::Http {
-            code,
-            message: value.to_string(),
-        }
+impl dyn Client {
+    pub fn is<T: Client>(&self) -> bool {
+        TypeId::of::<T>() == self.type_id(private::Seal)
     }
-}
 
-impl From<tonic::Status> for Error {
-    fn from(value: tonic::Status) -> Self {
-        use tonic::Code::*;
-        // Return equivalent http status code for grpc status code
-        let code = match value.code() {
-            InvalidArgument => StatusCode::BAD_REQUEST,
-            Internal => StatusCode::INTERNAL_SERVER_ERROR,
-            NotFound => StatusCode::NOT_FOUND,
-            DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
-            Unimplemented => StatusCode::NOT_IMPLEMENTED,
-            Unauthenticated => StatusCode::UNAUTHORIZED,
-            PermissionDenied => StatusCode::FORBIDDEN,
-            Unavailable => StatusCode::SERVICE_UNAVAILABLE,
-            Ok => StatusCode::OK,
-            _ => StatusCode::INTERNAL_SERVER_ERROR,
-        };
-        Self::Grpc {
-            code,
-            message: value.message().to_string(),
+    pub fn downcast<T: Client>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
+        if (*self).is::<T>() {
+            let ptr = Box::into_raw(self) as *mut T;
+            // SAFETY: guaranteed by `is`
+            unsafe { Ok(Box::from_raw(ptr)) }
+        } else {
+            Err(self)
         }
     }
-}
 
-#[derive(Debug, Clone, PartialEq)]
-pub enum ClientCode {
-    Http(StatusCode),
-    Grpc(tonic::Code),
-}
+    pub fn downcast_ref<T: Client>(&self) -> Option<&T> {
+        if (*self).is::<T>() {
+            let ptr = self as *const dyn Client as *const T;
+            // SAFETY: guaranteed by `is`
+            unsafe { Some(&*ptr) }
+        } else {
+            None
+        }
+    }
 
-impl Display for ClientCode {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            ClientCode::Http(code) => write!(f, "HTTP {}", code),
-            ClientCode::Grpc(code) => write!(f, "gRPC {:?} {}", code, code),
write!(f, "gRPC {:?} {}", code, code), + pub fn downcast_mut(&mut self) -> Option<&mut T> { + if (*self).is::() { + let ptr = self as *mut dyn Client as *mut T; + // SAFETY: guaranteed by `is` + unsafe { Some(&mut *ptr) } + } else { + None } } } -#[derive(Clone)] -pub struct HttpClient { - base_url: Url, - health_url: Url, - client: reqwest::Client, -} +/// A map containing different types of clients. +#[derive(Default)] +pub struct ClientMap(HashMap>); -impl HttpClient { - pub fn new(base_url: Url, client: reqwest::Client) -> Self { - let health_url = extract_base_url(&base_url).join("health").unwrap(); - Self { - base_url, - health_url, - client, - } +impl ClientMap { + /// Creates an empty `ClientMap`. + #[inline] + pub fn new() -> Self { + Self(HashMap::new()) } - pub fn base_url(&self) -> &Url { - &self.base_url + /// Inserts a client into the map. + #[inline] + pub fn insert(&mut self, key: String, value: V) { + self.0.insert(key, Box::new(value)); } - /// This is sectioned off to allow for testing. - pub(super) async fn http_response_to_health_check_result( - res: Result, - ) -> HealthCheckResult { - match res { - Ok(response) => { - if response.status() == StatusCode::OK { - if let Ok(body) = response.json::().await { - // If the service provided a body, we only anticipate a minimal health status and optional reason. - HealthCheckResult { - health_status: body.health_status.clone(), - response_code: ClientCode::Http(StatusCode::OK), - reason: match body.health_status { - HealthStatus::Healthy => None, - _ => body.reason, - }, - } - } else { - // If the service did not provide a body, we assume it is healthy. - HealthCheckResult { - health_status: HealthStatus::Healthy, - response_code: ClientCode::Http(StatusCode::OK), - reason: None, - } - } - } else { - HealthCheckResult { - // The most we can presume is that 5xx errors are likely indicating service issues, implying the service is unhealthy. - // and that 4xx errors are more likely indicating health check failures, i.e. due to configuration/implementation issues. - // Regardless we can't be certain, so the reason is also provided. - // TODO: We will likely circle back to re-evaluate this logic in the future - // when we know more about how the client health results will be used. - health_status: if response.status().as_u16() >= 500 - && response.status().as_u16() < 600 - { - HealthStatus::Unhealthy - } else if response.status().as_u16() >= 400 - && response.status().as_u16() < 500 - { - HealthStatus::Unknown - } else { - error!( - "unexpected http health check status code: {}", - response.status() - ); - HealthStatus::Unknown - }, - response_code: ClientCode::Http(response.status()), - reason: Some(format!( - "{}{}", - response.error_for_status_ref().unwrap_err(), - response - .text() - .await - .map(|s| if s.is_empty() { - "".to_string() - } else { - format!(": {}", s) - }) - .unwrap_or("".to_string()) - )), - } - } - } - Err(e) => { - error!("error checking health: {}", e); - HealthCheckResult { - health_status: HealthStatus::Unknown, - response_code: ClientCode::Http( - e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), - ), - reason: Some(e.to_string()), - } - } - } + /// Returns a reference to the client trait object. 
+    #[inline]
+    pub fn get(&self, key: &str) -> Option<&dyn Client> {
+        self.0.get(key).map(|v| v.as_ref())
     }
-}
 
-impl HealthCheck for HttpClient {
-    async fn check(&self) -> HealthCheckResult {
-        let res = self.get(self.health_url.clone()).send().await;
-        Self::http_response_to_health_check_result(res).await
+    /// Returns a mutable reference to the client trait object.
+    #[inline]
+    pub fn get_mut(&mut self, key: &str) -> Option<&mut dyn Client> {
+        self.0.get_mut(key).map(|v| v.as_mut())
+    }
+
+    /// Downcasts and returns a reference to the concrete client type.
+    #[inline]
+    pub fn get_as<V: Client>(&self, key: &str) -> Option<&V> {
+        self.0.get(key)?.downcast_ref::<V>()
     }
-}
 
-impl std::ops::Deref for HttpClient {
-    type Target = reqwest::Client;
+    /// Downcasts and returns a mutable reference to the concrete client type.
+    #[inline]
+    pub fn get_mut_as<V: Client>(&mut self, key: &str) -> Option<&mut V> {
+        self.0.get_mut(key)?.downcast_mut::<V>()
+    }
 
-    fn deref(&self) -> &Self::Target {
-        &self.client
+    /// Removes a client from the map.
+    #[inline]
+    pub fn remove(&mut self, key: &str) -> Option<Box<dyn Client>> {
+        self.0.remove(key)
+    }
+
+    /// An iterator visiting all key-value pairs in arbitrary order.
+    #[inline]
+    pub fn iter(&self) -> hash_map::Iter<'_, String, Box<dyn Client>> {
+        self.0.iter()
+    }
+
+    /// An iterator visiting all keys in arbitrary order.
+    #[inline]
+    pub fn keys(&self) -> hash_map::Keys<'_, String, Box<dyn Client>> {
+        self.0.keys()
+    }
+
+    /// An iterator visiting all values in arbitrary order.
+    #[inline]
+    pub fn values(&self) -> hash_map::Values<'_, String, Box<dyn Client>> {
+        self.0.values()
+    }
+
+    /// Returns the number of elements in the map.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    /// Returns `true` if the map contains no elements.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
     }
 }
 
-pub async fn create_http_clients(
-    default_port: u16,
-    config: &[(String, ServiceConfig)],
-) -> HashMap<String, HttpClient> {
-    let clients = config
-        .iter()
-        .map(|(name, service_config)| async move {
-            let port = service_config.port.unwrap_or(default_port);
-            let mut base_url = Url::parse(&service_config.hostname).unwrap();
-            base_url.set_port(Some(port)).unwrap();
-            let request_timeout = Duration::from_secs(
-                service_config
-                    .request_timeout
-                    .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
-            );
-            let mut builder = reqwest::ClientBuilder::new()
-                .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
-                .timeout(request_timeout);
-            if let Some(Tls::Config(tls_config)) = &service_config.tls {
-                let mut cert_buf = Vec::new();
-                let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
-                File::open(cert_path)
-                    .await
-                    .unwrap_or_else(|error| {
-                        panic!("error reading cert from {cert_path:?}: {error}")
-                    })
-                    .read_to_end(&mut cert_buf)
-                    .await
-                    .unwrap();
-
-                if let Some(key_path) = &tls_config.key_path {
-                    File::open(key_path)
-                        .await
-                        .unwrap_or_else(|error| {
-                            panic!("error reading key from {key_path:?}: {error}")
-                        })
-                        .read_to_end(&mut cert_buf)
-                        .await
-                        .unwrap();
-                }
-                let identity = reqwest::Identity::from_pem(&cert_buf).unwrap_or_else(|error| {
-                    panic!("error parsing bundled client certificate: {error}")
-                });
-
-                builder = builder.use_rustls_tls().identity(identity);
-                builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false));
-
-                if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
-                    let ca_cert =
-                        tokio::fs::read(client_ca_cert_path)
-                            .await
-                            .unwrap_or_else(|error| {
-                                panic!("error reading cert from {client_ca_cert_path:?}: {error}")
-                            });
-                    let cacert = reqwest::Certificate::from_pem(&ca_cert)
-                        .unwrap_or_else(|error| panic!("error parsing ca cert: {error}"));
-                    builder = builder.add_root_certificate(cacert)
-                }
-            }
-            let client = builder
-                .build()
-                .unwrap_or_else(|error| panic!("error creating http client for {name}: {error}"));
-            let client = HttpClient::new(base_url, client);
-            (name.clone(), client)
-        })
-        .collect::<Vec<_>>();
-    join_all(clients).await.into_iter().collect()
+#[instrument(skip_all, fields(hostname = service_config.hostname))]
+pub async fn create_http_client(default_port: u16, service_config: &ServiceConfig) -> HttpClient {
+    let port = service_config.port.unwrap_or(default_port);
+    let protocol = match service_config.tls {
+        Some(_) => "https",
+        None => "http",
+    };
+    let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname))
+        .unwrap_or_else(|e| panic!("error parsing base url: {}", e));
+    base_url
+        .set_port(Some(port))
+        .unwrap_or_else(|_| panic!("error setting port: {}", port));
+    debug!(%base_url, "creating HTTP client");
+    let request_timeout = Duration::from_secs(
+        service_config
+            .request_timeout
+            .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
+    );
+    let mut builder = reqwest::ClientBuilder::new()
+        .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
+        .timeout(request_timeout);
+    if let Some(Tls::Config(tls_config)) = &service_config.tls {
+        let mut cert_buf = Vec::new();
+        let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
+        File::open(cert_path)
+            .await
+            .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"))
+            .read_to_end(&mut cert_buf)
+            .await
+            .unwrap();
+
+        if let Some(key_path) = &tls_config.key_path {
+            File::open(key_path)
+                .await
+                .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"))
+                .read_to_end(&mut cert_buf)
+                .await
+                .unwrap();
+        }
+        let identity = reqwest::Identity::from_pem(&cert_buf)
+            .unwrap_or_else(|error| panic!("error parsing bundled client certificate: {error}"));
+
+        builder = builder.use_rustls_tls().identity(identity);
+        builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false));
+
+        if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
+            let ca_cert = tokio::fs::read(client_ca_cert_path)
+                .await
+                .unwrap_or_else(|error| {
+                    panic!("error reading cert from {client_ca_cert_path:?}: {error}")
+                });
+            let cacert = reqwest::Certificate::from_pem(&ca_cert)
+                .unwrap_or_else(|error| panic!("error parsing ca cert: {error}"));
+            builder = builder.add_root_certificate(cacert)
+        }
+    }
+    let client = builder
+        .build()
+        .unwrap_or_else(|error| panic!("error creating http client: {error}"));
+    HttpClient::new(base_url, client)
 }
 
-async fn create_grpc_clients<C>(
+#[instrument(skip_all, fields(hostname = service_config.hostname))]
+pub async fn create_grpc_client<C>(
     default_port: u16,
-    config: &[(String, ServiceConfig)],
+    service_config: &ServiceConfig,
     new: fn(LoadBalancedChannel) -> C,
-) -> HashMap<String, C> {
-    let clients = config
-        .iter()
-        .map(|(name, service_config)| async move {
-            let request_timeout = Duration::from_secs(
-                service_config
-                    .request_timeout
-                    .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
-            );
-            let mut builder = LoadBalancedChannel::builder((
-                service_config.hostname.clone(),
-                service_config.port.unwrap_or(default_port),
-            ))
-            .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
-            .timeout(request_timeout);
-
-            let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
-                let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
-                let key_path = tls_config.key_path.as_ref().unwrap().as_path();
-                let cert_pem = tokio::fs::read(cert_path)
-                    .await
-                    .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"));
-                let key_pem = tokio::fs::read(key_path)
-                    .await
-                    .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"));
-                let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
-                let mut client_tls_config = tonic::transport::ClientTlsConfig::new()
-                    .identity(identity)
-                    .with_native_roots()
-                    .with_webpki_roots();
-                if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
-                    let client_ca_cert_pem = tokio::fs::read(client_ca_cert_path)
-                        .await
-                        .unwrap_or_else(|error| {
-                            panic!("error reading client ca cert from {client_ca_cert_path:?}: {error}")
-                        });
-                    client_tls_config = client_tls_config.ca_certificate(
-                        tonic::transport::Certificate::from_pem(client_ca_cert_pem),
-                    );
-                }
-                Some(client_tls_config)
-            } else {
-                None
-            };
-            if let Some(client_tls_config) = client_tls_config {
-                builder = builder.with_tls(client_tls_config);
-            }
-            let channel = builder
-                .channel()
-                .await
-                .unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}"));
-            (name.clone(), new(channel))
-        })
-        .collect::<Vec<_>>();
-    join_all(clients).await.into_iter().collect()
+) -> C {
+    let port = service_config.port.unwrap_or(default_port);
+    let protocol = match service_config.tls {
+        Some(_) => "https",
+        None => "http",
+    };
+    let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap();
+    base_url.set_port(Some(port)).unwrap();
+    debug!(%base_url, "creating gRPC client");
+    let request_timeout = Duration::from_secs(
+        service_config
+            .request_timeout
+            .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
+    );
+    let mut builder = LoadBalancedChannel::builder((service_config.hostname.clone(), port))
+        .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
+        .timeout(request_timeout);
+
+    let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
+        let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
+        let key_path = tls_config.key_path.as_ref().unwrap().as_path();
+        let cert_pem = tokio::fs::read(cert_path)
+            .await
+            .unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"));
+        let key_pem = tokio::fs::read(key_path)
+            .await
+            .unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"));
+        let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
+        let mut client_tls_config = tonic::transport::ClientTlsConfig::new()
+            .identity(identity)
+            .with_native_roots()
+            .with_webpki_roots();
+        if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
+            let client_ca_cert_pem =
+                tokio::fs::read(client_ca_cert_path)
+                    .await
+                    .unwrap_or_else(|error| {
+                        panic!("error reading client ca cert from {client_ca_cert_path:?}: {error}")
+                    });
+            client_tls_config = client_tls_config
+                .ca_certificate(tonic::transport::Certificate::from_pem(client_ca_cert_pem));
+        }
+        Some(client_tls_config)
+    } else {
+        None
+    };
+    if let Some(client_tls_config) = client_tls_config {
+        builder = builder.with_tls(client_tls_config);
+    }
+    let channel = builder
+        .channel()
+        .await
+        .unwrap_or_else(|error| panic!("error creating grpc client: {error}"));
+    new(channel)
 }
 
-/// Extracts a base url from a url including path segments.
-fn extract_base_url(url: &Url) -> Url {
-    let mut url = url.clone();
-    match url.path_segments_mut() {
-        Ok(mut path) => {
-            path.clear();
-        }
-        Err(_) => {
-            panic!("url cannot be a base");
-        }
+/// Returns `true` if hostname is valid according to [IETF RFC 1123](https://tools.ietf.org/html/rfc1123).
+///
+/// Conditions:
+/// - It does not start or end with `-` or `.`.
+/// - It does not contain any characters outside of the alphanumeric range, except for `-` and `.`.
+/// - It is not empty.
+/// - It is 253 or fewer characters.
+/// - Its labels (characters separated by `.`) are not empty.
+/// - Its labels are 63 or fewer characters.
+/// - Its labels do not start or end with '-' or '.'.
+pub fn is_valid_hostname(hostname: &str) -> bool {
+    fn is_valid_char(byte: u8) -> bool {
+        byte.is_ascii_lowercase()
+            || byte.is_ascii_uppercase()
+            || byte.is_ascii_digit()
+            || byte == b'-'
+            || byte == b'.'
     }
-    url
+    !(hostname.bytes().any(|byte| !is_valid_char(byte))
+        || hostname.split('.').any(|label| {
+            label.is_empty() || label.len() > 63 || label.starts_with('-') || label.ends_with('-')
+        })
+        || hostname.is_empty()
+        || hostname.len() > 253)
+}
+
+/// Turns a gRPC client request body of type `T` and header map into a `tonic::Request<T>`.
+/// Will also inject the current `traceparent` header into the request based on the current span.
+fn grpc_request_with_headers<T>(request: T, headers: HeaderMap) -> Request<T> {
+    let headers = with_traceparent_header(headers);
+    let metadata = MetadataMap::from_headers(headers);
+    Request::from_parts(metadata, Extensions::new(), request)
 }
 
 #[cfg(test)]
 mod tests {
-    use hyper::http;
+    use errors::grpc_to_http_code;
+    use hyper::{http, StatusCode};
+    use reqwest::Response;
 
     use super::*;
-    use crate::pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse};
+    use crate::{
+        health::{HealthCheckResult, HealthStatus},
+        pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse},
+    };
 
     async fn mock_http_response(
         status: StatusCode,
@@ -429,83 +399,69 @@ mod tests {
         // READY responses from HTTP 200 OK with or without reason
         let response = [
             (StatusCode::OK, r#"{}"#),
-            (StatusCode::OK, r#"{ "health_status": "HEALTHY" }"#),
+            (StatusCode::OK, r#"{ "status": "HEALTHY" }"#),
+            (StatusCode::OK, r#"{ "status": "meaningless status" }"#),
             (
                 StatusCode::OK,
-                r#"{ "health_status": "meaningless status" }"#,
-            ),
-            (
-                StatusCode::OK,
-                r#"{ "health_status": "HEALTHY", "reason": "needless reason" }"#,
+                r#"{ "status": "HEALTHY", "reason": "needless reason" }"#,
             ),
         ];
         for (status, body) in response.iter() {
             let response = mock_http_response(*status, body).await;
             let result = HttpClient::http_response_to_health_check_result(response).await;
-            assert_eq!(result.health_status, HealthStatus::Healthy);
-            assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+            assert_eq!(result.status, HealthStatus::Healthy);
+            assert_eq!(result.code, StatusCode::OK);
             assert_eq!(result.reason, None);
             let serialized = serde_json::to_string(&result).unwrap();
-            assert_eq!(serialized, r#""HEALTHY""#);
+            assert_eq!(serialized, r#"{"status":"HEALTHY"}"#);
         }
 
         // NOT_READY response from HTTP 200 OK without reason
-        let response =
-            mock_http_response(StatusCode::OK, r#"{ "health_status": "UNHEALTHY" }"#).await;
+        let response = mock_http_response(StatusCode::OK, r#"{ "status": "UNHEALTHY" }"#).await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unhealthy);
-        assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+        assert_eq!(result.status, HealthStatus::Unhealthy);
+        assert_eq!(result.code, StatusCode::OK);
        assert_eq!(result.reason, None);
        let serialized = serde_json::to_string(&result).unwrap();
-        assert_eq!(
-            serialized,
-            r#"{"health_status":"UNHEALTHY","response_code":"HTTP 200 OK"}"#
-        );
+        assert_eq!(serialized, r#"{"status":"UNHEALTHY"}"#);
 
         // UNKNOWN response from HTTP 200 OK without reason
-        let response =
-            mock_http_response(StatusCode::OK, r#"{ "health_status": "UNKNOWN" }"#).await;
+        let response = mock_http_response(StatusCode::OK, r#"{ "status": "UNKNOWN" }"#).await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unknown);
-        assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+        assert_eq!(result.status, HealthStatus::Unknown);
+        assert_eq!(result.code, StatusCode::OK);
         assert_eq!(result.reason, None);
         let serialized = serde_json::to_string(&result).unwrap();
-        assert_eq!(
-            serialized,
-            r#"{"health_status":"UNKNOWN","response_code":"HTTP 200 OK"}"#
-        );
+        assert_eq!(serialized, r#"{"status":"UNKNOWN"}"#);
 
         // NOT_READY response from HTTP 200 OK with reason
         let response = mock_http_response(
             StatusCode::OK,
-            r#"{ "health_status": "UNHEALTHY", "reason": "some reason" }"#,
+            r#"{"status": "UNHEALTHY", "reason": "some reason" }"#,
         )
         .await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unhealthy);
-        assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+        assert_eq!(result.status, HealthStatus::Unhealthy);
+        assert_eq!(result.code, StatusCode::OK);
         assert_eq!(result.reason, Some("some reason".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
         assert_eq!(
             serialized,
-            r#"{"health_status":"UNHEALTHY","response_code":"HTTP 200 OK","reason":"some reason"}"#
+            r#"{"status":"UNHEALTHY","reason":"some reason"}"#
         );
 
         // UNKNOWN response from HTTP 200 OK with reason
         let response = mock_http_response(
             StatusCode::OK,
-            r#"{ "health_status": "UNKNOWN", "reason": "some reason" }"#,
+            r#"{ "status": "UNKNOWN", "reason": "some reason" }"#,
         )
         .await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unknown);
-        assert_eq!(result.response_code, ClientCode::Http(StatusCode::OK));
+        assert_eq!(result.status, HealthStatus::Unknown);
+        assert_eq!(result.code, StatusCode::OK);
         assert_eq!(result.reason, Some("some reason".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
-        assert_eq!(
-            serialized,
-            r#"{"health_status":"UNKNOWN","response_code":"HTTP 200 OK","reason":"some reason"}"#
-        );
+        assert_eq!(serialized, r#"{"status":"UNKNOWN","reason":"some reason"}"#);
 
         // NOT_READY response from HTTP 503 SERVICE UNAVAILABLE with reason
         let response = mock_http_response(
@@ -514,16 +470,13 @@ mod tests {
         )
         .await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unhealthy);
-        assert_eq!(
-            result.response_code,
-            ClientCode::Http(StatusCode::SERVICE_UNAVAILABLE)
-        );
-        assert_eq!(result.reason, Some(r#"HTTP status server error (503 Service Unavailable) for url (http://no.url.provided.local/): { "message": "some error message" }"#.to_string()));
+        assert_eq!(result.status, HealthStatus::Unhealthy);
+        assert_eq!(result.code, StatusCode::SERVICE_UNAVAILABLE);
+        assert_eq!(result.reason, Some("Service Unavailable".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
         assert_eq!(
             serialized,
-            r#"{"health_status":"UNHEALTHY","response_code":"HTTP 503 Service Unavailable","reason":"HTTP status server error (503 Service Unavailable) for url (http://no.url.provided.local/): { \"message\": \"some error message\" }"}"#
+            r#"{"status":"UNHEALTHY","code":503,"reason":"Service Unavailable"}"#
         );
 
         // UNKNOWN response from HTTP 404 NOT FOUND with reason
@@ -533,46 +486,37 @@ mod tests {
         )
         .await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unknown);
-        assert_eq!(
-            result.response_code,
-            ClientCode::Http(StatusCode::NOT_FOUND)
-        );
-        assert_eq!(result.reason, Some(r#"HTTP status client error (404 Not Found) for url (http://no.url.provided.local/): { "message": "service not found" }"#.to_string()));
+        assert_eq!(result.status, HealthStatus::Unknown);
+        assert_eq!(result.code, StatusCode::NOT_FOUND);
+        assert_eq!(result.reason, Some("Not Found".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
         assert_eq!(
             serialized,
-            r#"{"health_status":"UNKNOWN","response_code":"HTTP 404 Not Found","reason":"HTTP status client error (404 Not Found) for url (http://no.url.provided.local/): { \"message\": \"service not found\" }"}"#
+            r#"{"status":"UNKNOWN","code":404,"reason":"Not Found"}"#
         );
 
         // NOT_READY response from HTTP 500 INTERNAL SERVER ERROR without reason
         let response = mock_http_response(StatusCode::INTERNAL_SERVER_ERROR, r#""#).await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unhealthy);
-        assert_eq!(
-            result.response_code,
-            ClientCode::Http(StatusCode::INTERNAL_SERVER_ERROR)
-        );
-        assert_eq!(result.reason, Some("HTTP status server error (500 Internal Server Error) for url (http://no.url.provided.local/)".to_string()));
+        assert_eq!(result.status, HealthStatus::Unhealthy);
+        assert_eq!(result.code, StatusCode::INTERNAL_SERVER_ERROR);
+        assert_eq!(result.reason, Some("Internal Server Error".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
         assert_eq!(
             serialized,
-            r#"{"health_status":"UNHEALTHY","response_code":"HTTP 500 Internal Server Error","reason":"HTTP status server error (500 Internal Server Error) for url (http://no.url.provided.local/)"}"#
+            r#"{"status":"UNHEALTHY","code":500,"reason":"Internal Server Error"}"#
         );
 
         // UNKNOWN response from HTTP 400 BAD REQUEST without reason
         let response = mock_http_response(StatusCode::BAD_REQUEST, r#""#).await;
         let result = HttpClient::http_response_to_health_check_result(response).await;
-        assert_eq!(result.health_status, HealthStatus::Unknown);
-        assert_eq!(
-            result.response_code,
-            ClientCode::Http(StatusCode::BAD_REQUEST)
-        );
-        assert_eq!(result.reason, Some("HTTP status client error (400 Bad Request) for url (http://no.url.provided.local/)".to_string()));
+        assert_eq!(result.status, HealthStatus::Unknown);
+        assert_eq!(result.code, StatusCode::BAD_REQUEST);
+        assert_eq!(result.reason, Some("Bad Request".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
         assert_eq!(
             serialized,
-            r#"{"health_status":"UNKNOWN","response_code":"HTTP 400 Bad Request","reason":"HTTP status client error (400 Bad Request) for url (http://no.url.provided.local/)"}"#
+            r#"{"status":"UNKNOWN","code":400,"reason":"Bad Request"}"#
         );
     }
 
@@ -581,59 +525,47 @@ mod tests {
         // READY responses from gRPC 0 OK from serving status 1 SERVING
         let response = mock_grpc_response(Some(ServingStatus::Serving as i32), None).await;
         let result = HealthCheckResult::from(response);
-        assert_eq!(result.health_status, HealthStatus::Healthy);
-        assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
+        assert_eq!(result.status, HealthStatus::Healthy);
+        assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
         assert_eq!(result.reason, None);
         let serialized = serde_json::to_string(&result).unwrap();
-        assert_eq!(serialized, r#""HEALTHY""#);
+        assert_eq!(serialized, r#"{"status":"HEALTHY"}"#);
 
         // NOT_READY response from gRPC 0 OK from serving status 2 NOT_SERVING
         let response = mock_grpc_response(Some(ServingStatus::NotServing as i32), None).await;
         let result = HealthCheckResult::from(response);
-        assert_eq!(result.health_status, HealthStatus::Unhealthy);
-        assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
-        assert_eq!(
-            result.reason,
-            Some("from gRPC health check serving status: NOT_SERVING".to_string())
-        );
+        assert_eq!(result.status, HealthStatus::Unhealthy);
+        assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
+        assert_eq!(result.reason, Some("NOT_SERVING".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
         assert_eq!(
             serialized,
-            r#"{"health_status":"UNHEALTHY","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: NOT_SERVING"}"#
+            r#"{"status":"UNHEALTHY","reason":"NOT_SERVING"}"#
         );
 
         // UNKNOWN response from gRPC 0 OK from serving status 0 UNKNOWN
         let response = mock_grpc_response(Some(ServingStatus::Unknown as i32), None).await;
         let result = HealthCheckResult::from(response);
-        assert_eq!(result.health_status, HealthStatus::Unknown);
-        assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
-        assert_eq!(
-            result.reason,
-            Some("from gRPC health check serving status: UNKNOWN".to_string())
-        );
+        assert_eq!(result.status, HealthStatus::Unknown);
+        assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
+        assert_eq!(result.reason, Some("UNKNOWN".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
-        assert_eq!(
-            serialized,
-            r#"{"health_status":"UNKNOWN","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: UNKNOWN"}"#
-        );
+        assert_eq!(serialized, r#"{"status":"UNKNOWN","reason":"UNKNOWN"}"#);
 
         // UNKNOWN response from gRPC 0 OK from serving status 3 SERVICE_UNKNOWN
         let response = mock_grpc_response(Some(ServingStatus::ServiceUnknown as i32), None).await;
         let result = HealthCheckResult::from(response);
-        assert_eq!(result.health_status, HealthStatus::Unknown);
-        assert_eq!(result.response_code, ClientCode::Grpc(tonic::Code::Ok));
-        assert_eq!(
-            result.reason,
-            Some("from gRPC health check serving status: SERVICE_UNKNOWN".to_string())
-        );
+        assert_eq!(result.status, HealthStatus::Unknown);
+        assert_eq!(result.code, grpc_to_http_code(tonic::Code::Ok));
+        assert_eq!(result.reason, Some("SERVICE_UNKNOWN".to_string()));
         let serialized = serde_json::to_string(&result).unwrap();
         assert_eq!(
             serialized,
-            r#"{"health_status":"UNKNOWN","response_code":"gRPC Ok The operation completed successfully","reason":"from gRPC health check serving status: SERVICE_UNKNOWN"}"#
+            r#"{"status":"UNKNOWN","reason":"SERVICE_UNKNOWN"}"#
         );
 
         // UNKNOWN response from other gRPC error codes (covering main ones)
-        let response_codes = [
+        let codes = [
             tonic::Code::InvalidArgument,
             tonic::Code::Internal,
             tonic::Code::NotFound,
@@ -642,40 +574,40 @@ mod tests {
             tonic::Code::PermissionDenied,
             tonic::Code::Unavailable,
         ];
-        for code in response_codes.iter() {
-            let status = tonic::Status::new(*code, "some error message");
+        for code in codes.into_iter() {
+            let status = tonic::Status::new(code, "some error message");
+            let code = grpc_to_http_code(code);
             let response = mock_grpc_response(None, Some(status.clone())).await;
             let result = HealthCheckResult::from(response);
-            assert_eq!(result.health_status, HealthStatus::Unknown);
-            assert_eq!(result.response_code, ClientCode::Grpc(*code));
-            assert_eq!(
-                result.reason,
-                Some(format!("gRPC health check failed: {}", status.clone()))
-            );
+            assert_eq!(result.status, HealthStatus::Unknown);
+            assert_eq!(result.code, code);
+            assert_eq!(result.reason, Some("some error message".to_string()));
            let serialized = serde_json::to_string(&result).unwrap();
            assert_eq!(
                serialized,
                format!(
-                    r#"{{"health_status":"UNKNOWN","response_code":"gRPC {:?} {}","reason":"gRPC health check failed: status: {:?}, message: \"some error message\", details: [], metadata: MetadataMap {{ headers: {{}} }}"}}"#,
-                    code, code, code
-                )
+                    r#"{{"status":"UNKNOWN","code":{},"reason":"some error message"}}"#,
+                    code.as_u16()
+                ),
            );
        }
    }
 
    #[test]
-    fn test_extract_base_url() {
-        let url =
-            Url::parse("https://example-detector.route.example.com/api/v1/text/contents").unwrap();
-        let base_url = extract_base_url(&url);
-        assert_eq!(
-            Url::parse("https://example-detector.route.example.com/").unwrap(),
-            base_url
-        );
-        let health_url = base_url.join("/health").unwrap();
-        assert_eq!(
-            Url::parse("https://example-detector.route.example.com/health").unwrap(),
-            health_url
-        );
+    fn test_is_valid_hostname() {
+        let valid_hostnames = ["localhost", "example.route.cloud.com", "127.0.0.1"];
+        for hostname in valid_hostnames {
+            assert!(is_valid_hostname(hostname));
+        }
+        let invalid_hostnames = [
+            "-LoCaLhOST_",
+            ".invalid",
+            "invalid.ending-.char",
+            "@asdf",
+            "too-long-of-a-hostnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
+        ];
+        for hostname in invalid_hostnames {
+            assert!(!is_valid_hostname(hostname));
+        }
    }
 }
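The `ClientMap` introduced in this file stores heterogeneous clients behind `Box<dyn Client>` and recovers concrete types through the sealed `type_id` check. A hedged usage sketch; `DummyClient` is a hypothetical stand-in, and the `HealthCheckResult` fields are inferred from the tests above:

```rust
use hyper::StatusCode;

// Hypothetical client type, purely for illustration.
struct DummyClient;

#[async_trait::async_trait]
impl Client for DummyClient {
    fn name(&self) -> &str {
        "dummy"
    }

    async fn health(&self) -> HealthCheckResult {
        HealthCheckResult {
            status: HealthStatus::Healthy,
            code: StatusCode::OK,
            reason: None,
        }
    }
}

fn demo() {
    let mut clients = ClientMap::new();
    clients.insert("dummy".into(), DummyClient);
    // Dynamic access through the trait object:
    assert_eq!(clients.get("dummy").unwrap().name(), "dummy");
    // Concrete access via the sealed TypeId-based downcast:
    assert!(clients.get_as::<DummyClient>("dummy").is_some());
}
```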
diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs
index 35b04591..bd924b90 100644
--- a/src/clients/chunker.rs
+++ b/src/clients/chunker.rs
@@ -15,19 +15,22 @@
  */
 
-use std::{collections::HashMap, pin::Pin};
+use std::pin::Pin;
 
+use async_trait::async_trait;
+use axum::http::HeaderMap;
 use futures::{Future, Stream, StreamExt, TryStreamExt};
 use ginepro::LoadBalancedChannel;
-use tokio::sync::mpsc;
-use tokio_stream::wrappers::ReceiverStream;
-use tonic::{Request, Response, Status, Streaming};
-use tracing::info;
+use tonic::{Code, Request, Response, Status, Streaming};
+use tracing::{info, instrument};
 
-use super::{create_grpc_clients, BoxStream, Error};
+use super::{
+    create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client,
+    Error,
+};
 use crate::{
     config::ServiceConfig,
-    health::{HealthCheckResult, HealthProbe},
+    health::{HealthCheckResult, HealthStatus},
     pb::{
         caikit::runtime::chunkers::{
             chunkers_service_client::ChunkersServiceClient,
@@ -36,115 +39,104 @@ use crate::{
         caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults},
         grpc::health::v1::{health_client::HealthClient, HealthCheckRequest},
     },
+    tracing_utils::trace_context_from_grpc_response,
 };
 
+const DEFAULT_PORT: u16 = 8085;
 const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
 
 /// Default chunker that returns span for entire text
-pub const DEFAULT_MODEL_ID: &str = "whole_doc_chunker";
+pub const DEFAULT_CHUNKER_ID: &str = "whole_doc_chunker";
 
 type StreamingTokenizationResult =
     Result<Response<Streaming<ChunkerTokenizationStreamResult>>, Status>;
 
-#[cfg_attr(test, faux::create, derive(Default))]
+#[cfg_attr(test, faux::create)]
 #[derive(Clone)]
 pub struct ChunkerClient {
-    clients: HashMap<String, ChunkersServiceClient<LoadBalancedChannel>>,
-    health_clients: HashMap<String, HealthClient<LoadBalancedChannel>>,
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for ChunkerClient {
-    async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
-        let mut results = HashMap::with_capacity(self.health_clients.len());
-        for (model_id, mut client) in self.health_clients.clone() {
-            results.insert(
-                model_id.clone(),
-                client
-                    .check(HealthCheckRequest {
-                        service: "".to_string(),
-                    }) // Caikit does not expect a service_id to be specified
-                    .await
-                    .into(),
-            );
-        }
-        Ok(results)
-    }
+    client: ChunkersServiceClient<LoadBalancedChannel>,
+    health_client: HealthClient<LoadBalancedChannel>,
 }
 
 #[cfg_attr(test, faux::methods)]
 impl ChunkerClient {
-    pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
-        let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await;
-        let health_clients = create_grpc_clients(default_port, config, HealthClient::new).await;
+    pub async fn new(config: &ServiceConfig) -> Self {
+        let client = create_grpc_client(DEFAULT_PORT, config, ChunkersServiceClient::new).await;
+        let health_client = create_grpc_client(DEFAULT_PORT, config, HealthClient::new).await;
         Self {
-            clients,
-            health_clients,
+            client,
+            health_client,
         }
     }
 
-    fn client(&self, model_id: &str) -> Result<ChunkersServiceClient<LoadBalancedChannel>, Error> {
-        Ok(self
-            .clients
-            .get(model_id)
-            .ok_or_else(|| Error::ModelNotFound {
-                model_id: model_id.to_string(),
-            })?
-            .clone())
-    }
-
+    #[instrument(skip_all, fields(model_id))]
     pub async fn tokenization_task_predict(
         &self,
         model_id: &str,
         request: ChunkerTokenizationTaskRequest,
     ) -> Result<TokenizationResults, Error> {
-        // Handle "default" separately first
-        if model_id == DEFAULT_MODEL_ID {
-            info!("Using default whole doc chunker");
-            return Ok(tokenize_whole_doc(request));
-        }
-        let request = request_with_model_id(request, model_id);
-        Ok(self
-            .client(model_id)?
-            .chunker_tokenization_task_predict(request)
-            .await?
-            .into_inner())
+        let mut client = self.client.clone();
+        let request = request_with_headers(request, model_id);
+        info!(?request, "sending client request");
+        let response = client.chunker_tokenization_task_predict(request).await?;
+        trace_context_from_grpc_response(&response);
+        Ok(response.into_inner())
     }
 
+    #[instrument(skip_all, fields(model_id))]
     pub async fn bidi_streaming_tokenization_task_predict(
         &self,
         model_id: &str,
         request_stream: BoxStream<BidiStreamingChunkerTokenizationTaskRequest>,
     ) -> Result<BoxStream<Result<ChunkerTokenizationStreamResult, Error>>, Error> {
-        let response_stream = if model_id == DEFAULT_MODEL_ID {
-            info!("Using default whole doc chunker");
-            let (response_tx, response_rx) = mpsc::channel(1);
-            // Spawn task to collect input stream
-            tokio::spawn(async move {
-                // NOTE: this will not resolve until the input stream is closed
-                let response = tokenize_whole_doc_stream(request_stream).await;
-                let _ = response_tx.send(response).await;
-            });
-            ReceiverStream::new(response_rx).boxed()
+        info!("sending client stream request");
+        let mut client = self.client.clone();
+        let request = request_with_headers(request_stream, model_id);
+        // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors.
+        // https://github.com/rust-lang/rust/issues/110338
+        let response_stream_fut: Pin<Box<dyn Future<Output = StreamingTokenizationResult> + Send>> =
+            Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request));
+        let response_stream = response_stream_fut.await?;
+        trace_context_from_grpc_response(&response_stream);
+        Ok(response_stream.into_inner().map_err(Into::into).boxed())
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for ChunkerClient {
+    fn name(&self) -> &str {
+        "chunker"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        let mut client = self.health_client.clone();
+        let response = client
+            .check(HealthCheckRequest { service: "".into() })
+            .await;
+        let code = match response {
+            Ok(_) => Code::Ok,
+            Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => {
+                Code::Ok
+            }
+            Err(status) => status.code(),
+        };
+        let status = if matches!(code, Code::Ok) {
+            HealthStatus::Healthy
         } else {
-            let mut client = self.client(model_id)?;
-            let request = request_with_model_id(request_stream, model_id);
-            // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors.
-            // https://github.com/rust-lang/rust/issues/110338
-            let response_stream_fut: Pin<
-                Box<dyn Future<Output = StreamingTokenizationResult> + Send>,
-            > = Box::pin(client.bidi_streaming_chunker_tokenization_task_predict(request));
-            response_stream_fut
-                .await?
-                .into_inner()
-                .map_err(Into::into)
-                .boxed()
+            HealthStatus::Unhealthy
         };
-        Ok(response_stream)
+        HealthCheckResult {
+            status,
+            code: grpc_to_http_code(code),
+            reason: None,
+        }
     }
 }
 
-fn request_with_model_id<T>(request: T, model_id: &str) -> Request<T> {
-    let mut request = Request::new(request);
+/// Turns a chunker client gRPC request body of type `T` into a `tonic::Request<T>` with headers.
+/// Adds the provided `model_id` as a header as well as injects `traceparent` from the current span.
+fn request_with_headers<T>(request: T, model_id: &str) -> Request<T> {
+    let mut request = grpc_request_with_headers(request, HeaderMap::new());
     request
         .metadata_mut()
         .insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
@@ -152,7 +144,8 @@ fn request_with_model_id<T>(request: T, model_id: &str) -> Request<T> {
 }
 
 /// Unary tokenization result of the entire doc
-fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults {
+#[instrument(skip_all)]
+pub fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults {
     let codepoint_count = request.text.chars().count() as i64;
     TokenizationResults {
         results: vec![Token {
@@ -165,7 +158,8 @@ fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationRe
 }
 
 /// Streaming tokenization result for the entire doc stream
-async fn tokenize_whole_doc_stream(
+#[instrument(skip_all)]
+pub async fn tokenize_whole_doc_stream(
     request: impl Stream<Item = BidiStreamingChunkerTokenizationTaskRequest>,
 ) -> Result<ChunkerTokenizationStreamResult, Error> {
     let (text, index_vec): (String, Vec<i64>) = request
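For orientation, a hedged sketch of what the whole-doc fallback above produces, assuming `ChunkerTokenizationTaskRequest` carries only the input `text` and `Token` exposes `start`/`end`/`text` as in the caikit protos:

```rust
fn whole_doc_example() {
    let results = tokenize_whole_doc(ChunkerTokenizationTaskRequest {
        text: "Hi there".into(),
    });
    // One span covering the entire input, measured in codepoints, not bytes.
    assert_eq!(results.results.len(), 1);
    assert_eq!(results.results[0].start, 0);
    assert_eq!(results.results[0].end, 8);
}
```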
diff --git a/src/clients/detector.rs b/src/clients/detector.rs
index 455612bc..cf06ddef 100644
--- a/src/clients/detector.rs
+++ b/src/clients/detector.rs
@@ -15,234 +15,30 @@
  */
 
-use std::collections::HashMap;
+use std::fmt::Debug;
 
-use hyper::{HeaderMap, StatusCode};
+use axum::http::HeaderMap;
+use hyper::StatusCode;
+use reqwest::Response;
 use serde::{Deserialize, Serialize};
-
-use super::{create_http_clients, Error, HttpClient};
-use crate::{
-    config::ServiceConfig,
-    health::{HealthCheck, HealthCheckResult, HealthProbe},
-    models::{DetectionResult, DetectorParams},
-};
-
+use tracing::info;
+use url::Url;
+
+pub mod text_contents;
+pub use text_contents::*;
+pub mod text_chat;
+pub use text_chat::*;
+pub mod text_context_doc;
+pub use text_context_doc::*;
+pub mod text_generation;
+pub use text_generation::*;
+
+use super::{Error, HttpClient};
+use crate::tracing_utils::{trace_context_from_http_response, with_traceparent_header};
+
+const DEFAULT_PORT: u16 = 8080;
 const DETECTOR_ID_HEADER_NAME: &str = "detector-id";
 
-// For some reason the order matters here. #[cfg_attr(test, derive(Default), faux::create)] doesn't work. (rustc --explain E0560)
-#[cfg_attr(test, faux::create, derive(Default))]
-#[derive(Clone)]
-pub struct DetectorClient {
-    clients: HashMap<String, HttpClient>,
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for DetectorClient {
-    async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
-        let mut results = HashMap::with_capacity(self.clients.len());
-        for (model_id, client) in self.clients() {
-            results.insert(model_id.to_string(), client.check().await);
-        }
-        Ok(results)
-    }
-}
-
-#[cfg_attr(test, faux::methods)]
-impl DetectorClient {
-    pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
-        let clients: HashMap<String, HttpClient> =
-            create_http_clients(default_port, config).await;
-        Self { clients }
-    }
-
-    fn client(&self, model_id: &str) -> Result<HttpClient, Error> {
-        Ok(self
-            .clients
-            .get(model_id)
-            .ok_or_else(|| Error::ModelNotFound {
-                model_id: model_id.to_string(),
-            })?
-            .clone())
-    }
-
-    fn clients(&self) -> impl Iterator<Item = (&String, &HttpClient)> {
-        self.clients.iter()
-    }
-
-    // TODO: Use generics here, since the only thing that changes in comparison to generation_detection()
-    // is the "request" parameter and return types?
-    /// Invokes detectors implemented with the `/api/v1/text/contents` endpoint
-    pub async fn text_contents(
-        &self,
-        model_id: &str,
-        request: ContentAnalysisRequest,
-        headers: HeaderMap,
-    ) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
-        let client = self.client(model_id)?;
-        let url = client.base_url().as_str();
-        let response = client
-            .post(url)
-            .headers(headers)
-            .header(DETECTOR_ID_HEADER_NAME, model_id)
-            .json(&request)
-            .send()
-            .await?;
-        if response.status() == StatusCode::OK {
-            Ok(response.json().await?)
-        } else {
-            let code = response.status().as_u16();
-            let error = response
-                .json::<DetectorError>()
-                .await
-                .unwrap_or(DetectorError {
-                    code,
-                    message: "".into(),
-                });
-            Err(error.into())
-        }
-    }
-
-    /// Invokes detectors implemented with the `/api/v1/text/generation` endpoint
-    pub async fn text_generation(
-        &self,
-        model_id: &str,
-        request: GenerationDetectionRequest,
-        headers: HeaderMap,
-    ) -> Result<Vec<DetectionResult>, Error> {
-        let client = self.client(model_id)?;
-        let url = client.base_url().as_str();
-        let response = client
-            .post(url)
-            .headers(headers)
-            .header(DETECTOR_ID_HEADER_NAME, model_id)
-            .json(&request)
-            .send()
-            .await?;
-        if response.status() == StatusCode::OK {
-            Ok(response.json().await?)
-        } else {
-            let code = response.status().as_u16();
-            let error = response
-                .json::<DetectorError>()
-                .await
-                .unwrap_or(DetectorError {
-                    code,
-                    message: "".into(),
-                });
-            Err(error.into())
-        }
-    }
-
-    /// Invokes detectors implemented with the `/api/v1/text/context/doc` endpoint
-    pub async fn text_context_doc(
-        &self,
-        model_id: &str,
-        request: ContextDocsDetectionRequest,
-        headers: HeaderMap,
-    ) -> Result<Vec<DetectionResult>, Error> {
-        let client = self.client(model_id)?;
-        let url = client.base_url().as_str();
-        let response = client
-            .post(url)
-            .headers(headers)
-            .header(DETECTOR_ID_HEADER_NAME, model_id)
-            .json(&request)
-            .send()
-            .await?;
-        if response.status() == StatusCode::OK {
-            Ok(response.json().await?)
-        } else {
-            let code = response.status().as_u16();
-            let error = response
-                .json::<DetectorError>()
-                .await
-                .unwrap_or(DetectorError {
-                    code,
-                    message: "".into(),
-                });
-            Err(error.into())
-        }
-    }
-}
-
-/// Request for text content analysis
-/// Results of this request will contain analysis / detection of each of the provided documents
-/// in the order they are present in the `contents` object.
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub struct ContentAnalysisRequest {
-    /// Field allowing users to provide list of documents for analysis
-    pub contents: Vec<String>,
-}
-
-impl ContentAnalysisRequest {
-    pub fn new(contents: Vec<String>) -> ContentAnalysisRequest {
-        ContentAnalysisRequest { contents }
-    }
-}
-
-/// Evidence
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub struct Evidence {
-    /// Evidence name
-    pub name: String,
-    /// Optional, evidence value
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub value: Option<String>,
-    /// Optional, score for evidence
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub score: Option<f64>,
-}
-
-/// Evidence in response
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
-pub struct EvidenceObj {
-    /// Evidence name
-    pub name: String,
-    /// Optional, evidence value
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub value: Option<String>,
-    /// Optional, score for evidence
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub score: Option<f64>,
-    /// Optional, evidence on evidence value
-    // Evidence nesting should likely not go beyond this
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub evidence: Option<Vec<Evidence>>,
-}
-
-/// Response of text content analysis endpoint
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub struct ContentAnalysisResponse {
-    /// Start index of detection
-    pub start: usize,
-    /// End index of detection
-    pub end: usize,
-    /// Text corresponding to detection
-    pub text: String,
-    /// Relevant detection class
-    pub detection: String,
-    /// Detection type or aggregate detection label
-    pub detection_type: String,
-    /// Score of detection
-    pub score: f64,
-    /// Optional, any applicable evidence for detection
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub evidence: Option<Vec<EvidenceObj>>,
-}
-
-impl From<ContentAnalysisResponse> for crate::models::TokenClassificationResult {
-    fn from(value: ContentAnalysisResponse) -> Self {
-        Self {
-            start: value.start as u32,
-            end: value.end as u32,
-            word: value.text,
-            entity: value.detection,
-            entity_group: value.detection_type,
-            score: value.score,
-            token_count: None,
-        }
-    }
-}
-
 #[derive(Debug, Clone, Deserialize)]
 pub struct DetectorError {
     pub code: u16,
@@ -258,66 +54,24 @@ impl From<DetectorError> for Error {
     }
 }
 
-/// A struct representing a request to a detector compatible with the
-/// /api/v1/text/generation endpoint.
-#[cfg_attr(test, derive(PartialEq))]
-#[derive(Debug, Serialize)]
-pub struct GenerationDetectionRequest {
-    /// User prompt sent to LLM
-    pub prompt: String,
-
-    /// Text generated from an LLM
-    pub generated_text: String,
-}
-
-impl GenerationDetectionRequest {
-    pub fn new(prompt: String, generated_text: String) -> Self {
-        Self {
-            prompt,
-            generated_text,
-        }
-    }
-}
-
-/// A struct representing a request to a detector compatible with the
-/// /api/v1/text/context/doc endpoint.
-#[cfg_attr(test, derive(PartialEq))]
-#[derive(Debug, Serialize)]
-pub struct ContextDocsDetectionRequest {
-    /// Content to run detection on
-    pub content: String,
-
-    /// Type of context being sent
-    pub context_type: ContextType,
-
-    /// Context to run detection on
-    pub context: Vec<String>,
-
-    // Detector Params
-    pub detector_params: DetectorParams,
-}
-
-/// Enum representing the context type of a detection
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
-pub enum ContextType {
-    #[serde(rename = "docs")]
-    Document,
-    #[serde(rename = "url")]
-    Url,
-}
-
-impl ContextDocsDetectionRequest {
-    pub fn new(
-        content: String,
-        context_type: ContextType,
-        context: Vec<String>,
-        detector_params: DetectorParams,
-    ) -> Self {
-        Self {
-            content,
-            context_type,
-            context,
-            detector_params,
-        }
-    }
+/// Make a POST request for an HTTP detector client and return the response.
+/// Also injects the `traceparent` header from the current span and traces the response.
+pub async fn post_with_headers<T: Debug + Serialize>(
+    client: HttpClient,
+    url: Url,
+    request: T,
+    headers: HeaderMap,
+    model_id: &str,
+) -> Result<Response, Error> {
+    let headers = with_traceparent_header(headers);
+    info!(?url, ?headers, ?request, "sending client request");
+    let response = client
+        .post(url)
+        .headers(headers)
+        .header(DETECTOR_ID_HEADER_NAME, model_id)
+        .json(&request)
+        .send()
+        .await?;
+    trace_context_from_http_response(&response);
+    Ok(response)
 }
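Each detector client below repeats the same response-handling convention after `post_with_headers`. A possible shared helper (not part of this change) capturing that convention, shown purely as a sketch:

```rust
use serde::de::DeserializeOwned;

// 200 OK deserializes the body; anything else is parsed as a DetectorError,
// falling back to an empty message when the error body cannot be decoded.
async fn parse_detector_response<R: DeserializeOwned>(response: Response) -> Result<R, Error> {
    if response.status() == StatusCode::OK {
        Ok(response.json().await?)
    } else {
        let code = response.status().as_u16();
        let error = response
            .json::<DetectorError>()
            .await
            .unwrap_or(DetectorError {
                code,
                message: "".into(),
            });
        Err(error.into())
    }
}
```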
diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs
new file mode 100644
index 00000000..bb6fc524
--- /dev/null
+++ b/src/clients/detector/text_chat.rs
@@ -0,0 +1,125 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::Serialize;
+use tracing::{debug, info, instrument};
+
+use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME};
+use crate::{
+    clients::{create_http_client, openai::Message, Client, Error, HttpClient},
+    config::ServiceConfig,
+    health::HealthCheckResult,
+    models::{DetectionResult, DetectorParams},
+};
+
+const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat";
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextChatDetectorClient {
+    client: HttpClient,
+    health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextChatDetectorClient {
+    pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+        let client = create_http_client(DEFAULT_PORT, config).await;
+        let health_client = if let Some(health_config) = health_config {
+            Some(create_http_client(DEFAULT_PORT, health_config).await)
+        } else {
+            None
+        };
+        Self {
+            client,
+            health_client,
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id, ?headers))]
+    pub async fn text_chat(
+        &self,
+        model_id: &str,
+        request: ChatDetectionRequest,
+        headers: HeaderMap,
+    ) -> Result<Vec<DetectionResult>, Error> {
+        let url = self.client.base_url().join(CHAT_DETECTOR_ENDPOINT).unwrap();
+        info!(?url, "sending chat detector client request");
+        let request = self
+            .client
+            .post(url)
+            .headers(headers)
+            .header(DETECTOR_ID_HEADER_NAME, model_id)
+            .json(&request);
+        debug!("chat detector client request: {:?}", request);
+        let response = request.send().await?;
+        debug!("chat detector client response: {:?}", response);
+
+        if response.status() == StatusCode::OK {
+            Ok(response.json().await?)
+        } else {
+            let code = response.status().as_u16();
+            let error = response
+                .json::<DetectorError>()
+                .await
+                .unwrap_or(DetectorError {
+                    code,
+                    message: "".into(),
+                });
+            Err(error.into())
+        }
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextChatDetectorClient {
+    fn name(&self) -> &str {
+        "text_chat_detector"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        if let Some(health_client) = &self.health_client {
+            health_client.health().await
+        } else {
+            self.client.health().await
+        }
+    }
+}
+
+/// A struct representing a request to a detector compatible with the
+/// /api/v1/text/chat endpoint.
+// #[cfg_attr(test, derive(PartialEq))]
+#[derive(Debug, Serialize)]
+pub struct ChatDetectionRequest {
+    /// Chat messages to run detection on
+    pub messages: Vec<Message>,
+
+    /// Detector parameters (available parameters depend on the detector)
+    pub detector_params: DetectorParams,
+}
+
+impl ChatDetectionRequest {
+    pub fn new(messages: Vec<Message>, detector_params: DetectorParams) -> Self {
+        Self {
+            messages,
+            detector_params,
+        }
+    }
+}
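A hedged wiring sketch for the client above; `detector_config` and `health_config` stand in for real `ServiceConfig` values loaded elsewhere:

```rust
async fn wire_chat_detector(detector_config: &ServiceConfig, health_config: &ServiceConfig) {
    let detector = TextChatDetectorClient::new(detector_config, Some(health_config)).await;
    // health() prefers the dedicated health client; had `None` been passed
    // above, it would fall back to the primary client's health endpoint.
    let _result = detector.health().await;
}
```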
diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs
new file mode 100644
index 00000000..d1015b1d
--- /dev/null
+++ b/src/clients/detector/text_contents.rs
@@ -0,0 +1,182 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::{Deserialize, Serialize};
+use tracing::instrument;
+
+use super::{post_with_headers, DetectorError, DEFAULT_PORT};
+use crate::{
+    clients::{create_http_client, Client, Error, HttpClient},
+    config::ServiceConfig,
+    health::HealthCheckResult,
+    models::DetectorParams,
+};
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextContentsDetectorClient {
+    client: HttpClient,
+    health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextContentsDetectorClient {
+    pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+        let client = create_http_client(DEFAULT_PORT, config).await;
+        let health_client = if let Some(health_config) = health_config {
+            Some(create_http_client(DEFAULT_PORT, health_config).await)
+        } else {
+            None
+        };
+        Self {
+            client,
+            health_client,
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id))]
+    pub async fn text_contents(
+        &self,
+        model_id: &str,
+        request: ContentAnalysisRequest,
+        headers: HeaderMap,
+    ) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
+        let url = self
+            .client
+            .base_url()
+            .join("/api/v1/text/contents")
+            .unwrap();
+        let response =
+            post_with_headers(self.client.clone(), url, request, headers, model_id).await?;
+        if response.status() == StatusCode::OK {
+            Ok(response.json().await?)
+        } else {
+            let code = response.status().as_u16();
+            let error = response
+                .json::<DetectorError>()
+                .await
+                .unwrap_or(DetectorError {
+                    code,
+                    message: "".into(),
+                });
+            Err(error.into())
+        }
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextContentsDetectorClient {
+    fn name(&self) -> &str {
+        "text_contents_detector"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        if let Some(health_client) = &self.health_client {
+            health_client.health().await
+        } else {
+            self.client.health().await
+        }
+    }
+}
+
+/// Request for text content analysis
+/// Results of this request will contain analysis / detection of each of the provided documents
+/// in the order they are present in the `contents` object.
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct ContentAnalysisRequest {
+    /// Field allowing users to provide list of documents for analysis
+    pub contents: Vec<String>,
+
+    /// Detector parameters (available parameters depend on the detector)
+    pub detector_params: DetectorParams,
+}
+
+impl ContentAnalysisRequest {
+    pub fn new(contents: Vec<String>, detector_params: DetectorParams) -> ContentAnalysisRequest {
+        ContentAnalysisRequest {
+            contents,
+            detector_params,
+        }
+    }
+}
+
+/// Response of text content analysis endpoint
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct ContentAnalysisResponse {
+    /// Start index of detection
+    pub start: usize,
+    /// End index of detection
+    pub end: usize,
+    /// Text corresponding to detection
+    pub text: String,
+    /// Relevant detection class
+    pub detection: String,
+    /// Detection type or aggregate detection label
+    pub detection_type: String,
+    /// Score of detection
+    pub score: f64,
+    /// Optional, any applicable evidence for detection
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub evidence: Option<Vec<EvidenceObj>>,
+}
+
+impl From<ContentAnalysisResponse> for crate::models::TokenClassificationResult {
+    fn from(value: ContentAnalysisResponse) -> Self {
+        Self {
+            start: value.start as u32,
+            end: value.end as u32,
+            word: value.text,
+            entity: value.detection,
+            entity_group: value.detection_type,
+            score: value.score,
+            token_count: None,
+        }
+    }
+}
+
+/// Evidence
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct Evidence {
+    /// Evidence name
+    pub name: String,
+    /// Optional, evidence value
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub value: Option<String>,
+    /// Optional, score for evidence
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub score: Option<f64>,
+}
+
+/// Evidence in response
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct EvidenceObj {
+    /// Evidence name
+    pub name: String,
+    /// Optional, evidence value
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub value: Option<String>,
+    /// Optional, score for evidence
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub score: Option<f64>,
+    /// Optional, evidence on evidence value
+    // Evidence nesting should likely not go beyond this
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub evidence: Option<Vec<Evidence>>,
+}
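An illustrative construction of the request type above; `DetectorParams` is assumed here to implement `Default` and serialize to an empty map:

```rust
fn build_contents_request() -> ContentAnalysisRequest {
    // Serializes roughly as:
    // {"contents":["first document","second document"],"detector_params":{}}
    ContentAnalysisRequest::new(
        vec!["first document".into(), "second document".into()],
        DetectorParams::default(),
    )
}
```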
diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs
new file mode 100644
index 00000000..e56c1619
--- /dev/null
+++ b/src/clients/detector/text_context_doc.rs
@@ -0,0 +1,140 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::{Deserialize, Serialize};
+use tracing::instrument;
+
+use super::{post_with_headers, DetectorError, DEFAULT_PORT};
+use crate::{
+    clients::{create_http_client, Client, Error, HttpClient},
+    config::ServiceConfig,
+    health::HealthCheckResult,
+    models::{DetectionResult, DetectorParams},
+};
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextContextDocDetectorClient {
+    client: HttpClient,
+    health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextContextDocDetectorClient {
+    pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+        let client = create_http_client(DEFAULT_PORT, config).await;
+        let health_client = if let Some(health_config) = health_config {
+            Some(create_http_client(DEFAULT_PORT, health_config).await)
+        } else {
+            None
+        };
+        Self {
+            client,
+            health_client,
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id))]
+    pub async fn text_context_doc(
+        &self,
+        model_id: &str,
+        request: ContextDocsDetectionRequest,
+        headers: HeaderMap,
+    ) -> Result<Vec<DetectionResult>, Error> {
+        let url = self
+            .client
+            .base_url()
+            .join("/api/v1/text/context/doc")
+            .unwrap();
+        let response =
+            post_with_headers(self.client.clone(), url, request, headers, model_id).await?;
+        if response.status() == StatusCode::OK {
+            Ok(response.json().await?)
+        } else {
+            let code = response.status().as_u16();
+            let error = response
+                .json::<DetectorError>()
+                .await
+                .unwrap_or(DetectorError {
+                    code,
+                    message: "".into(),
+                });
+            Err(error.into())
+        }
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextContextDocDetectorClient {
+    fn name(&self) -> &str {
+        "text_context_doc_detector"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        if let Some(health_client) = &self.health_client {
+            health_client.health().await
+        } else {
+            self.client.health().await
+        }
+    }
+}
+
+/// A struct representing a request to a detector compatible with the
+/// /api/v1/text/context/doc endpoint.
+#[cfg_attr(test, derive(PartialEq))]
+#[derive(Debug, Serialize)]
+pub struct ContextDocsDetectionRequest {
+    /// Content to run detection on
+    pub content: String,
+
+    /// Type of context being sent
+    pub context_type: ContextType,
+
+    /// Context to run detection on
+    pub context: Vec<String>,
+
+    /// Detector parameters (available parameters depend on the detector)
+    pub detector_params: DetectorParams,
+}
+
+/// Enum representing the context type of a detection
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub enum ContextType {
+    #[serde(rename = "docs")]
+    Document,
+    #[serde(rename = "url")]
+    Url,
+}
+
+impl ContextDocsDetectionRequest {
+    pub fn new(
+        content: String,
+        context_type: ContextType,
+        context: Vec<String>,
+        detector_params: DetectorParams,
+    ) -> Self {
+        Self {
+            content,
+            context_type,
+            context,
+            detector_params,
+        }
+    }
+}
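The serde renames on `ContextType` above determine the wire names. A small test-style sketch confirming the expected serialization:

```rust
#[test]
fn context_type_serializes_to_wire_names() {
    assert_eq!(
        serde_json::to_string(&ContextType::Document).unwrap(),
        r#""docs""#
    );
    assert_eq!(serde_json::to_string(&ContextType::Url).unwrap(), r#""url""#);
}
```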
diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs
new file mode 100644
index 00000000..5a63d6c9
--- /dev/null
+++ b/src/clients/detector/text_generation.rs
@@ -0,0 +1,122 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::Serialize;
+use tracing::instrument;
+
+use super::{post_with_headers, DetectorError, DEFAULT_PORT};
+use crate::{
+    clients::{create_http_client, Client, Error, HttpClient},
+    config::ServiceConfig,
+    health::HealthCheckResult,
+    models::{DetectionResult, DetectorParams},
+};
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct TextGenerationDetectorClient {
+    client: HttpClient,
+    health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl TextGenerationDetectorClient {
+    pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+        let client = create_http_client(DEFAULT_PORT, config).await;
+        let health_client = if let Some(health_config) = health_config {
+            Some(create_http_client(DEFAULT_PORT, health_config).await)
+        } else {
+            None
+        };
+        Self {
+            client,
+            health_client,
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id))]
+    pub async fn text_generation(
+        &self,
+        model_id: &str,
+        request: GenerationDetectionRequest,
+        headers: HeaderMap,
+    ) -> Result<Vec<DetectionResult>, Error> {
+        let url = self
+            .client
+            .base_url()
+            .join("/api/v1/text/generation")
+            .unwrap();
+        let response =
+            post_with_headers(self.client.clone(), url, request, headers, model_id).await?;
+        if response.status() == StatusCode::OK {
+            Ok(response.json().await?)
+        } else {
+            let code = response.status().as_u16();
+            let error = response
+                .json::<DetectorError>()
+                .await
+                .unwrap_or(DetectorError {
+                    code,
+                    message: "".into(),
+                });
+            Err(error.into())
+        }
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TextGenerationDetectorClient {
+    fn name(&self) -> &str {
+        "text_generation_detector"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        if let Some(health_client) = &self.health_client {
+            health_client.health().await
+        } else {
+            self.client.health().await
+        }
+    }
+}
+
+/// A struct representing a request to a detector compatible with the
+/// /api/v1/text/generation endpoint.
+#[cfg_attr(test, derive(PartialEq))]
+#[derive(Debug, Serialize)]
+pub struct GenerationDetectionRequest {
+    /// User prompt sent to LLM
+    pub prompt: String,
+
+    /// Text generated from an LLM
+    pub generated_text: String,
+
+    /// Detector parameters (available parameters depend on the detector)
+    pub detector_params: DetectorParams,
+}
+
+impl GenerationDetectionRequest {
+    pub fn new(prompt: String, generated_text: String, detector_params: DetectorParams) -> Self {
+        Self {
+            prompt,
+            generated_text,
+            detector_params,
+        }
+    }
+}
diff --git a/src/clients/errors.rs b/src/clients/errors.rs
new file mode 100644
index 00000000..e630e21a
--- /dev/null
+++ b/src/clients/errors.rs
@@ -0,0 +1,96 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use std::error::Error as _;
+
+use hyper::StatusCode;
+use tracing::error;
+
+/// Client errors.
+#[derive(Debug, Clone, PartialEq, thiserror::Error)]
+pub enum Error {
+    #[error("{}", .message)]
+    Grpc { code: StatusCode, message: String },
+    #[error("{}", .message)]
+    Http { code: StatusCode, message: String },
+    #[error("model not found: {model_id}")]
+    ModelNotFound { model_id: String },
+}
+
+impl Error {
+    /// Returns status code.
+    pub fn status_code(&self) -> StatusCode {
+        match self {
+            // Return equivalent http status code for grpc status code
+            Error::Grpc { code, .. } => *code,
+            // Return http status code for error responses
+            // and 500 for other errors
+            Error::Http { code, .. } => *code,
+            // Return 404 for model not found
+            Error::ModelNotFound { .. } => StatusCode::NOT_FOUND,
+        }
+    }
+}
+
+impl From<reqwest::Error> for Error {
+    fn from(value: reqwest::Error) -> Self {
+        // Log lower level source of error.
+        // Examples:
+        // 1. client error (Connect) // Cases like connection error, wrong port etc.
+        // 2. client error (SendRequest) // Cases like cert issues
+        error!(
+            "http request failed. Source: {}",
+            value
+                .source()
+                .map(|source| source.to_string())
+                .unwrap_or_default()
+        );
+        // Return http status code for error responses
+        // and 500 for other errors
+        let code = match value.status() {
+            Some(code) => code,
+            None => StatusCode::INTERNAL_SERVER_ERROR,
+        };
+        Self::Http {
+            code,
+            message: value.to_string(),
+        }
+    }
+}
+
+impl From<tonic::Status> for Error {
+    fn from(value: tonic::Status) -> Self {
+        Self::Grpc {
+            code: grpc_to_http_code(value.code()),
+            message: value.message().to_string(),
+        }
+    }
+}
+
+/// Returns equivalent http status code for grpc status code
+pub fn grpc_to_http_code(value: tonic::Code) -> StatusCode {
+    use tonic::Code::*;
+    match value {
+        InvalidArgument => StatusCode::BAD_REQUEST,
+        Internal => StatusCode::INTERNAL_SERVER_ERROR,
+        NotFound => StatusCode::NOT_FOUND,
+        DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
+        Unimplemented => StatusCode::NOT_IMPLEMENTED,
+        Unauthenticated => StatusCode::UNAUTHORIZED,
+        PermissionDenied => StatusCode::FORBIDDEN,
+        Unavailable => StatusCode::SERVICE_UNAVAILABLE,
+        Ok => StatusCode::OK,
+        _ => StatusCode::INTERNAL_SERVER_ERROR,
+    }
+}
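As a quick sanity check on the mapping above, a test along these lines (illustrative, not in the diff) pins down a few of the translations, including the catch-all arm:

```rust
#[test]
fn grpc_codes_map_to_expected_http_statuses() {
    use hyper::StatusCode;
    use tonic::Code;

    assert_eq!(grpc_to_http_code(Code::NotFound), StatusCode::NOT_FOUND);
    assert_eq!(
        grpc_to_http_code(Code::Unavailable),
        StatusCode::SERVICE_UNAVAILABLE
    );
    // Codes without a natural HTTP twin (e.g. DataLoss) fall through to 500.
    assert_eq!(
        grpc_to_http_code(Code::DataLoss),
        StatusCode::INTERNAL_SERVER_ERROR
    );
}
```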
diff --git a/src/clients/generation.rs b/src/clients/generation.rs
index d599d8c1..c5068dc1 100644
--- a/src/clients/generation.rs
+++ b/src/clients/generation.rs
@@ -15,15 +15,14 @@
 */
 
-use std::collections::HashMap;
-
+use async_trait::async_trait;
 use futures::{StreamExt, TryStreamExt};
 use hyper::HeaderMap;
-use tracing::debug;
+use tracing::{debug, instrument};
 
-use super::{BoxStream, Error, NlpClient, TgisClient};
+use super::{BoxStream, Client, Error, NlpClient, NlpClientHttp, TgisClient};
 use crate::{
-    health::{HealthCheckResult, HealthProbe},
+    health::HealthCheckResult,
     models::{
         ClassifiedGeneratedTextResult, ClassifiedGeneratedTextStreamResult,
         GuardrailsTextGenerationParameters,
@@ -40,7 +39,7 @@ use crate::{
     },
 };
 
-#[cfg_attr(test, faux::create, derive(Default))]
+#[cfg_attr(test, faux::create)]
 #[derive(Clone)]
 pub struct GenerationClient(Option<GenerationClientInner>);
 
@@ -48,24 +47,7 @@ pub struct GenerationClient(Option<GenerationClientInner>);
 enum GenerationClientInner {
     Tgis(TgisClient),
     Nlp(NlpClient),
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for GenerationClient {
-    async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
-        match &self.0 {
-            Some(GenerationClientInner::Tgis(client)) => client.health().await,
-            Some(GenerationClientInner::Nlp(client)) => client.health().await,
-            None => Ok(HashMap::new()),
-        }
-    }
-}
-
-#[cfg(test)]
-impl Default for GenerationClientInner {
-    fn default() -> Self {
-        Self::Tgis(TgisClient::default())
-    }
+    NlpHttp(NlpClientHttp),
 }
 
 #[cfg_attr(test, faux::methods)]
@@ -78,10 +60,15 @@ impl GenerationClient {
         Self(Some(GenerationClientInner::Nlp(client)))
     }
 
+    pub fn nlp_http(client: NlpClientHttp) -> Self {
+        Self(Some(GenerationClientInner::NlpHttp(client)))
+    }
+
     pub fn not_configured() -> Self {
         Self(None)
     }
 
+    #[instrument(skip_all, fields(model_id))]
     pub async fn tokenize(
         &self,
         model_id: String,
@@ -97,19 +84,33 @@ impl GenerationClient {
                     return_offsets: false,
                     truncate_input_tokens: 0,
                 };
-                debug!(%model_id, provider = "tgis", ?request, "sending tokenize request");
+                debug!(provider = "tgis", ?request, "sending tokenize request");
                 let mut response = client.tokenize(request, headers).await?;
-                debug!(%model_id, provider = "tgis", ?response, "received tokenize response");
+                debug!(provider = "tgis", ?response, "received tokenize response");
                 let response = response.responses.swap_remove(0);
                 Ok((response.token_count, response.tokens))
             }
             Some(GenerationClientInner::Nlp(client)) => {
                 let request = TokenizationTaskRequest { text };
-                debug!(%model_id, provider = "nlp", ?request, "sending tokenize request");
+                debug!(provider = "nlp", ?request, "sending tokenize request");
+                let response = client
+                    .tokenization_task_predict(&model_id, request, headers)
+                    .await?;
+                debug!(provider = "nlp", ?response, "received tokenize response");
+                let tokens = response
+                    .results
+                    .into_iter()
+                    .map(|token| token.text)
+                    .collect::<Vec<_>>();
+                Ok((response.token_count as u32, tokens))
+            }
+            Some(GenerationClientInner::NlpHttp(client)) => {
+                let request = TokenizationTaskRequest { text };
+                debug!(provider = "nlp-http", ?request, "sending tokenize request");
                 let response = client
                     .tokenization_task_predict(&model_id, request, headers)
                     .await?;
-                debug!(%model_id, provider = "nlp", ?response, "received tokenize response");
+                debug!(provider = "nlp-http", ?response, "received tokenize response");
                 let tokens = response
                     .results
                    .into_iter()
+                    .map(|token| token.text)
+                    .collect::<Vec<_>>();
+                Ok((response.token_count as u32, tokens))
+            }
@@ -121,6 +122,7 @@ impl GenerationClient {
         }
     }
 
+    #[instrument(skip_all, fields(model_id))]
     pub async fn generate(
         &self,
         model_id: String,
@@ -137,9 +139,9 @@ impl GenerationClient {
                     requests: vec![GenerationRequest { text }],
                     params,
                 };
-                debug!(%model_id, provider = "tgis", ?request, "sending generate request");
+                debug!(provider = "tgis", ?request, "sending generate request");
                 let response = client.generate(request, headers).await?;
-                debug!(%model_id, provider = "tgis", ?response, "received generate response");
+                debug!(provider = "tgis", ?response, "received generate response");
                 Ok(response.into())
             }
             Some(GenerationClientInner::Nlp(client)) => {
@@ -174,17 +176,58 @@ impl GenerationClient {
                         ..Default::default()
                     }
                 };
-                debug!(%model_id, provider = "nlp", ?request, "sending generate request");
+                debug!(provider = "nlp", ?request, "sending generate request");
                 let response = client
                     .text_generation_task_predict(&model_id, request, headers)
                     .await?;
-                debug!(%model_id, provider = "nlp", ?response, "received generate response");
+                debug!(provider = "nlp", ?response, "received generate response");
                 Ok(response.into())
             }
generate response"); Ok(response.into()) } + Some(GenerationClientInner::NlpHttp(client)) => { + let request = if let Some(params) = params { + TextGenerationTaskRequest { + text, + max_new_tokens: params.max_new_tokens.map(|v| v as i64), + min_new_tokens: params.min_new_tokens.map(|v| v as i64), + truncate_input_tokens: params.truncate_input_tokens.map(|v| v as i64), + decoding_method: params.decoding_method, + top_k: params.top_k.map(|v| v as i64), + top_p: params.top_p, + typical_p: params.typical_p, + temperature: params.temperature, + repetition_penalty: params.repetition_penalty, + max_time: params.max_time, + exponential_decay_length_penalty: params + .exponential_decay_length_penalty + .map(Into::into), + stop_sequences: params.stop_sequences.unwrap_or_default(), + seed: params.seed.map(|v| v as u64), + preserve_input_text: params.preserve_input_text, + input_tokens: params.input_tokens, + generated_tokens: params.generated_tokens, + token_logprobs: params.token_logprobs, + token_ranks: params.token_ranks, + include_stop_sequence: params.include_stop_sequence, + } + } else { + TextGenerationTaskRequest { + text, + ..Default::default() + } + } + debug!(provider = "nlp-http", ?request, "sending generate request"); + let response = client + .text_generation_task_predict(&model_id, request, headers) + .await?; + debug!(provider = "nlp-http", ?response, "received generate response"); + Ok(response.into()) + } + }; None => Err(Error::ModelNotFound { model_id }), } } + #[instrument(skip_all, fields(model_id))] pub async fn generate_stream( &self, model_id: String, @@ -201,7 +244,11 @@ impl GenerationClient { request: Some(GenerationRequest { text }), params, }; - debug!(%model_id, provider = "tgis", ?request, "sending generate_stream request"); + debug!( + provider = "tgis", + ?request, + "sending generate_stream request" + ); let response_stream = client .generate_stream(request, headers) .await? @@ -241,7 +288,11 @@ impl GenerationClient { ..Default::default() } }; - debug!(%model_id, provider = "nlp", ?request, "sending generate_stream request"); + debug!( + provider = "nlp", + ?request, + "sending generate_stream request" + ); let response_stream = client .server_streaming_text_generation_task_predict(&model_id, request, headers) .await? 
@@ -249,7 +300,67 @@ impl GenerationClient {
                     .boxed();
                 Ok(response_stream)
             }
+            Some(GenerationClientInner::NlpHttp(client)) => {
+                let request = if let Some(params) = params {
+                    ServerStreamingTextGenerationTaskRequest {
+                        text,
+                        max_new_tokens: params.max_new_tokens.map(|v| v as i64),
+                        min_new_tokens: params.min_new_tokens.map(|v| v as i64),
+                        truncate_input_tokens: params.truncate_input_tokens.map(|v| v as i64),
+                        decoding_method: params.decoding_method,
+                        top_k: params.top_k.map(|v| v as i64),
+                        top_p: params.top_p,
+                        typical_p: params.typical_p,
+                        temperature: params.temperature,
+                        repetition_penalty: params.repetition_penalty,
+                        max_time: params.max_time,
+                        exponential_decay_length_penalty: params
+                            .exponential_decay_length_penalty
+                            .map(Into::into),
+                        stop_sequences: params.stop_sequences.unwrap_or_default(),
+                        seed: params.seed.map(|v| v as u64),
+                        preserve_input_text: params.preserve_input_text,
+                        input_tokens: params.input_tokens,
+                        generated_tokens: params.generated_tokens,
+                        token_logprobs: params.token_logprobs,
+                        token_ranks: params.token_ranks,
+                        include_stop_sequence: params.include_stop_sequence,
+                    }
+                } else {
+                    ServerStreamingTextGenerationTaskRequest {
+                        text,
+                        ..Default::default()
+                    }
+                };
+                debug!(
+                    provider = "nlp-http",
+                    ?request,
+                    "sending generate_stream request"
+                );
+                let response_stream = client
+                    .server_streaming_text_generation_task_predict(&model_id, request, headers)
+                    .await?
+                    .map_ok(Into::into)
+                    .boxed();
+                Ok(response_stream)
+            }
             None => Err(Error::ModelNotFound { model_id }),
         }
     }
 }
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for GenerationClient {
+    fn name(&self) -> &str {
+        "generation"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        match &self.0 {
+            Some(GenerationClientInner::Tgis(client)) => client.health().await,
+            Some(GenerationClientInner::Nlp(client)) => client.health().await,
+            Some(GenerationClientInner::NlpHttp(client)) => client.health().await,
+            None => unimplemented!(),
+        }
+    }
+}
diff --git a/src/clients/http.rs b/src/clients/http.rs
new file mode 100644
index 00000000..862db811
--- /dev/null
+++ b/src/clients/http.rs
@@ -0,0 +1,156 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use hyper::StatusCode;
+use reqwest::Response;
+use tracing::error;
+use url::Url;
+
+use crate::health::{HealthCheckResult, HealthStatus, OptionalHealthCheckResponseBody};
+
+#[derive(Clone)]
+pub struct HttpClient {
+    base_url: Url,
+    health_url: Url,
+    client: reqwest::Client,
+}
+
+impl HttpClient {
+    pub fn new(base_url: Url, client: reqwest::Client) -> Self {
+        let health_url = base_url.join("health").unwrap();
+        Self {
+            base_url,
+            health_url,
+            client,
+        }
+    }
+
+    pub fn base_url(&self) -> &Url {
+        &self.base_url
+    }
+
+    /// This is sectioned off to allow for testing.
+    pub(super) async fn http_response_to_health_check_result(
+        res: Result<Response, reqwest::Error>,
+    ) -> HealthCheckResult {
+        match res {
+            Ok(response) => {
+                if response.status() == StatusCode::OK {
+                    if let Ok(body) = response.json::<OptionalHealthCheckResponseBody>().await {
+                        // If the service provided a body, we only anticipate a minimal health status and optional reason.
+                        HealthCheckResult {
+                            status: body.status.clone(),
+                            code: StatusCode::OK,
+                            reason: match body.status {
+                                HealthStatus::Healthy => None,
+                                _ => body.reason,
+                            },
+                        }
+                    } else {
+                        // If the service did not provide a body, we assume it is healthy.
+                        HealthCheckResult {
+                            status: HealthStatus::Healthy,
+                            code: StatusCode::OK,
+                            reason: None,
+                        }
+                    }
+                } else {
+                    HealthCheckResult {
+                        // The most we can presume is that 5xx errors likely indicate service issues, implying the service is unhealthy,
+                        // and that 4xx errors more likely indicate health check failures, i.e. due to configuration/implementation issues.
+                        // Regardless, we can't be certain, so the reason is also provided.
+                        // TODO: We will likely circle back to re-evaluate this logic in the future
+                        // when we know more about how the client health results will be used.
+                        status: if response.status().is_server_error() {
+                            HealthStatus::Unhealthy
+                        } else if response.status().is_client_error() {
+                            HealthStatus::Unknown
+                        } else {
+                            error!(
+                                "unexpected http health check status code: {}",
+                                response.status()
+                            );
+                            HealthStatus::Unknown
+                        },
+                        code: response.status(),
+                        reason: response.status().canonical_reason().map(|v| v.to_string()),
+                    }
+                }
+            }
+            Err(e) => {
+                error!("error checking health: {}", e);
+                HealthCheckResult {
+                    status: HealthStatus::Unknown,
+                    code: e.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
+                    reason: Some(e.to_string()),
+                }
+            }
+        }
+    }
+
+    pub async fn health(&self) -> HealthCheckResult {
+        let res = self.get(self.health_url.clone()).send().await;
+        Self::http_response_to_health_check_result(res).await
+    }
+}
+
+impl std::ops::Deref for HttpClient {
+    type Target = reqwest::Client;
+
+    fn deref(&self) -> &Self::Target {
+        &self.client
+    }
+}
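The classification rules above are easiest to read through an example. A hypothetical test, assuming it lives in the same `clients` module (the helper is `pub(super)`) and that the `http` crate is reachable as a dependency; `reqwest::Response` implements `From<http::Response<T>>` for bodies convertible into `reqwest::Body`:

```rust
#[tokio::test]
async fn classifies_5xx_as_unhealthy() {
    // Build a bare 503 response and feed it through the classifier.
    let http_response = http::Response::builder().status(503).body("").unwrap();
    let result =
        HttpClient::http_response_to_health_check_result(Ok(http_response.into())).await;
    // 5xx falls in the server-error branch, so the service is reported unhealthy.
    assert!(matches!(result.status, HealthStatus::Unhealthy));
    assert_eq!(result.code, StatusCode::SERVICE_UNAVAILABLE);
}
```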
+/// Extracts a base url from a url including path segments.
+pub fn extract_base_url(url: &Url) -> Option<Url> {
+    let mut url = url.clone();
+    match url.path_segments_mut() {
+        Ok(mut path) => {
+            path.clear();
+        }
+        Err(_) => {
+            return None;
+        }
+    }
+    Some(url)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_extract_base_url() {
+        let url =
+            Url::parse("https://example-detector.route.example.com/api/v1/text/contents").unwrap();
+        let base_url = extract_base_url(&url);
+        assert_eq!(
+            Some(Url::parse("https://example-detector.route.example.com/").unwrap()),
+            base_url
+        );
+        let health_url = base_url.map(|v| v.join("/health").unwrap());
+        assert_eq!(
+            Some(Url::parse("https://example-detector.route.example.com/health").unwrap()),
+            health_url
+        );
+    }
+}
diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs
index 90b3f693..01067019 100644
--- a/src/clients/nlp.rs
+++ b/src/clients/nlp.rs
@@ -15,18 +15,20 @@
 */
 
-use std::collections::HashMap;
-
-use axum::http::{Extensions, HeaderMap};
+use async_trait::async_trait;
+use axum::http::HeaderMap;
 use futures::{StreamExt, TryStreamExt};
 use ginepro::LoadBalancedChannel;
-use tonic::{metadata::MetadataMap, Request};
+use tonic::{Code, Request};
+use tracing::{info, instrument};
 
-use super::{create_grpc_clients, BoxStream, Error};
+use super::{
+    create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client,
+    Error,
+};
 use crate::{
-    clients::COMMON_ROUTER_KEY,
     config::ServiceConfig,
-    health::{HealthCheckResult, HealthProbe},
+    health::{HealthCheckResult, HealthStatus},
     pb::{
         caikit::runtime::nlp::{
             nlp_service_client::NlpServiceClient, ServerStreamingTextGenerationTaskRequest,
@@ -38,122 +40,130 @@ use crate::{
         },
         grpc::health::v1::{health_client::HealthClient, HealthCheckRequest},
     },
+    tracing_utils::trace_context_from_grpc_response,
 };
 
+const DEFAULT_PORT: u16 = 8085;
 const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
 
-#[cfg_attr(test, faux::create, derive(Default))]
+#[cfg_attr(test, faux::create)]
 #[derive(Clone)]
 pub struct NlpClient {
-    clients: HashMap<String, NlpServiceClient<LoadBalancedChannel>>,
-    health_clients: HashMap<String, HealthClient<LoadBalancedChannel>>,
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for NlpClient {
-    async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
-        let mut results = HashMap::with_capacity(self.health_clients.len());
-        for (model_id, mut client) in self.health_clients.clone() {
-            results.insert(
-                model_id.clone(),
-                client
-                    .check(HealthCheckRequest {
-                        service: model_id.clone(),
-                    })
-                    .await
-                    .into(),
-            );
-        }
-        Ok(results)
-    }
+    client: NlpServiceClient<LoadBalancedChannel>,
+    health_client: HealthClient<LoadBalancedChannel>,
 }
 
 #[cfg_attr(test, faux::methods)]
 impl NlpClient {
-    pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
-        let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await;
-        let health_clients = create_grpc_clients(default_port, config, HealthClient::new).await;
+    pub async fn new(config: &ServiceConfig) -> Self {
+        let client = create_grpc_client(DEFAULT_PORT, config, NlpServiceClient::new).await;
+        let health_client = create_grpc_client(DEFAULT_PORT, config, HealthClient::new).await;
         Self {
-            clients,
-            health_clients,
+            client,
+            health_client,
         }
     }
 
-    fn client(&self, _model_id: &str) -> Result<NlpServiceClient<LoadBalancedChannel>, Error> {
-        // NOTE: We currently forward requests to common router, so we use a single client.
-        let model_id = COMMON_ROUTER_KEY;
-        Ok(self
-            .clients
-            .get(model_id)
-            .ok_or_else(|| Error::ModelNotFound {
-                model_id: model_id.to_string(),
-            })?
-            .clone())
-    }
-
+    #[instrument(skip_all, fields(model_id))]
     pub async fn tokenization_task_predict(
         &self,
         model_id: &str,
         request: TokenizationTaskRequest,
         headers: HeaderMap,
     ) -> Result<TokenizationResults, Error> {
-        let request = request_with_model_id(request, model_id, headers);
-        Ok(self
-            .client(model_id)?
-            .tokenization_task_predict(request)
-            .await?
-            .into_inner())
+        let mut client = self.client.clone();
+        let request = request_with_headers(request, model_id, headers);
+        info!(?request, "sending request to NLP gRPC service");
+        let response = client.tokenization_task_predict(request).await?;
+        trace_context_from_grpc_response(&response);
+        Ok(response.into_inner())
     }
 
+    #[instrument(skip_all, fields(model_id))]
     pub async fn token_classification_task_predict(
         &self,
         model_id: &str,
         request: TokenClassificationTaskRequest,
         headers: HeaderMap,
     ) -> Result<TokenClassificationResults, Error> {
-        let request = request_with_model_id(request, model_id, headers);
-        Ok(self
-            .client(model_id)?
-            .token_classification_task_predict(request)
-            .await?
-            .into_inner())
+        let mut client = self.client.clone();
+        let request = request_with_headers(request, model_id, headers);
+        info!(?request, "sending request to NLP gRPC service");
+        let response = client.token_classification_task_predict(request).await?;
+        trace_context_from_grpc_response(&response);
+        Ok(response.into_inner())
     }
 
+    #[instrument(skip_all, fields(model_id))]
     pub async fn text_generation_task_predict(
         &self,
         model_id: &str,
         request: TextGenerationTaskRequest,
         headers: HeaderMap,
     ) -> Result<GeneratedTextResult, Error> {
-        let request = request_with_model_id(request, model_id, headers);
-        Ok(self
-            .client(model_id)?
-            .text_generation_task_predict(request)
-            .await?
-            .into_inner())
+        let mut client = self.client.clone();
+        let request = request_with_headers(request, model_id, headers);
+        info!(?request, "sending request to NLP gRPC service");
+        let response = client.text_generation_task_predict(request).await?;
+        trace_context_from_grpc_response(&response);
+        Ok(response.into_inner())
    }
 
+    #[instrument(skip_all, fields(model_id))]
     pub async fn server_streaming_text_generation_task_predict(
         &self,
         model_id: &str,
         request: ServerStreamingTextGenerationTaskRequest,
         headers: HeaderMap,
     ) -> Result<BoxStream<Result<GeneratedTextStreamResult, Error>>, Error> {
-        let request = request_with_model_id(request, model_id, headers);
-        let response_stream = self
-            .client(model_id)?
+        let mut client = self.client.clone();
+        let request = request_with_headers(request, model_id, headers);
+        info!(?request, "sending stream request to NLP gRPC service");
+        let response = client
             .server_streaming_text_generation_task_predict(request)
-            .await?
-            .into_inner()
-            .map_err(Into::into)
-            .boxed();
+            .await?;
+        trace_context_from_grpc_response(&response);
+        let response_stream = response.into_inner().map_err(Into::into).boxed();
         Ok(response_stream)
     }
 }
 
-fn request_with_model_id<T>(request: T, model_id: &str, headers: HeaderMap) -> Request<T> {
-    let metadata = MetadataMap::from_headers(headers);
-    let mut request = Request::from_parts(metadata, Extensions::new(), request);
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for NlpClient {
+    fn name(&self) -> &str {
+        "nlp"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        let mut client = self.health_client.clone();
+        let response = client
+            .check(HealthCheckRequest { service: "".into() })
+            .await;
+        let code = match response {
+            Ok(_) => Code::Ok,
+            Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => {
+                Code::Ok
+            }
+            Err(status) => status.code(),
+        };
+        let status = if matches!(code, Code::Ok) {
+            HealthStatus::Healthy
+        } else {
+            HealthStatus::Unhealthy
+        };
+        HealthCheckResult {
+            status,
+            code: grpc_to_http_code(code),
+            reason: None,
+        }
+    }
+}
+
+/// Turns an NLP client gRPC request body of type `T` and headers into a `tonic::Request<T>`.
+/// Also injects the provided `model_id` and the `traceparent` from the current context into headers.
+fn request_with_headers<T>(request: T, model_id: &str, headers: HeaderMap) -> Request<T> {
+    let mut request = grpc_request_with_headers(request, headers);
     request
         .metadata_mut()
         .insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
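A sketch of what the helper above produces, assuming `grpc_request_with_headers` only copies the HTTP headers into tonic metadata (illustrative test, not part of the diff):

```rust
#[test]
fn model_id_is_injected_as_metadata() {
    let request = request_with_headers(
        TokenizationTaskRequest { text: "hi".into() },
        "my-model",
        HeaderMap::new(),
    );
    // The model id rides along as gRPC metadata under "mm-model-id".
    let value = request.metadata().get(MODEL_ID_HEADER_NAME).unwrap();
    assert_eq!(value.to_str().unwrap(), "my-model");
}
```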
diff --git a/src/clients/nlp_http.rs b/src/clients/nlp_http.rs
new file mode 100644
index 00000000..7b4c3645
--- /dev/null
+++ b/src/clients/nlp_http.rs
@@ -0,0 +1,184 @@
+use async_trait::async_trait;
+use futures::StreamExt;
+use hyper::{HeaderMap, StatusCode};
+use reqwest_eventsource::{Event, RequestBuilderExt};
+use tracing::{info, instrument};
+
+use super::{create_http_client, BoxStream, Client, Error, HttpClient};
+use crate::{
+    config::ServiceConfig,
+    health::HealthCheckResult,
+    pb::{
+        caikit::runtime::nlp::{
+            ServerStreamingTextGenerationTaskRequest, TextGenerationTaskRequest,
+            TokenClassificationTaskRequest, TokenizationTaskRequest,
+        },
+        caikit_data_model::nlp::{
+            GeneratedTextResult, GeneratedTextStreamResult, TokenClassificationResults,
+            TokenizationResults,
+        },
+    },
+    tracing_utils::with_traceparent_header,
+};
+
+const DEFAULT_PORT: u16 = 8085;
+const MODEL_ID_HEADER_NAME: &str = "mm-model-id";
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct NlpClientHttp {
+    client: HttpClient,
+    health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl NlpClientHttp {
+    pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+        let client = create_http_client(DEFAULT_PORT, config).await;
+        let health_client = if let Some(health_config) = health_config {
+            Some(create_http_client(DEFAULT_PORT, health_config).await)
+        } else {
+            None
+        };
+        Self {
+            client,
+            health_client,
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id))]
+    pub async fn tokenization_task_predict(
+        &self,
+        model_id: &str,
+        request: TokenizationTaskRequest,
+        headers: HeaderMap,
+    ) -> Result<TokenizationResults, Error> {
+        let url = self
+            .client
+            .base_url()
+            .join("/api/v1/task/tokenization")
+            .unwrap();
+        let mut headers = with_traceparent_header(headers);
+        headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+        info!(?request, "sending request to NLP http service");
+        let response = self
+            .client
+            .post(url)
+            .headers(headers)
+            .json(&request)
+            .send()
+            .await?;
+        match response.status() {
+            StatusCode::OK => Ok(response.json::<TokenizationResults>().await?),
+            _ => {
+                let code = response.status();
+                let message = response
+                    .text()
+                    .await
+                    .unwrap_or_else(|_| "unknown error occurred".into());
+                Err(Error::Http { code, message })
+            }
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id))]
+    pub async fn token_classification_task_predict(
+        &self,
+        model_id: &str,
+        request: TokenClassificationTaskRequest,
+        headers: HeaderMap,
+    ) -> Result<TokenClassificationResults, Error> {
+        let url = self
+            .client
+            .base_url()
+            .join("/api/v1/task/token-classification")
+            .unwrap();
+        let mut headers = with_traceparent_header(headers);
+        headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+        info!(?request, "sending request to NLP http service");
+        let response = self
+            .client
+            .post(url)
+            .headers(headers)
+            .json(&request)
+            .send()
+            .await?;
+        match response.status() {
+            StatusCode::OK => Ok(response.json::<TokenClassificationResults>().await?),
+            _ => {
+                let code = response.status();
+                let message = response
+                    .text()
+                    .await
+                    .unwrap_or_else(|_| "unknown error occurred".into());
+                Err(Error::Http { code, message })
+            }
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id))]
+    pub async fn text_generation_task_predict(
+        &self,
+        model_id: &str,
+        request: TextGenerationTaskRequest,
+        headers: HeaderMap,
+    ) -> Result<GeneratedTextResult, Error> {
+        let url = self
+            .client
+            .base_url()
+            .join("/api/v1/task/text-generation")
+            .unwrap();
+        let mut headers = with_traceparent_header(headers);
+        headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+        info!(?request, "sending request to NLP http service");
+        let response = self
+            .client
+            .post(url)
+            .headers(headers)
+            .json(&request)
+            .send()
+            .await?;
+        match response.status() {
+            StatusCode::OK => Ok(response.json::<GeneratedTextResult>().await?),
+            _ => {
+                let code = response.status();
+                let message = response
+                    .text()
+                    .await
+                    .unwrap_or_else(|_| "unknown error occurred".into());
+                Err(Error::Http { code, message })
+            }
+        }
+    }
+
+    #[instrument(skip_all, fields(model_id))]
+    pub async fn server_streaming_text_generation_task_predict(
+        &self,
+        model_id: &str,
+        request: ServerStreamingTextGenerationTaskRequest,
+        headers: HeaderMap,
+    ) -> Result<BoxStream<Result<GeneratedTextStreamResult, Error>>, Error> {
+        let url = self
+            .client
+            .base_url()
+            .join("/api/v1/task/streaming-text-generation")
+            .unwrap();
+        let mut headers = with_traceparent_header(headers);
+        headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
+        info!(?request, "sending stream request to NLP http service");
+        let mut event_stream = self
+            .client
+            .post(url)
+            .headers(headers)
+            .json(&request)
+            .eventsource()
+            .unwrap();
+        // Deserialize each server-sent event into a stream item.
+        let response_stream = async_stream::stream! {
+            while let Some(result) = event_stream.next().await {
+                match result {
+                    Ok(Event::Open) => continue,
+                    Ok(Event::Message(message)) => {
+                        match serde_json::from_str::<GeneratedTextStreamResult>(&message.data) {
+                            Ok(item) => yield Ok(item),
+                            Err(error) => yield Err(Error::Http {
+                                code: StatusCode::INTERNAL_SERVER_ERROR,
+                                message: error.to_string(),
+                            }),
+                        }
+                    }
+                    Err(reqwest_eventsource::Error::StreamEnded) => break,
+                    Err(error) => yield Err(Error::Http {
+                        code: StatusCode::INTERNAL_SERVER_ERROR,
+                        message: error.to_string(),
+                    }),
+                }
+            }
+        };
+        Ok(response_stream.boxed())
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for NlpClientHttp {
+    fn name(&self) -> &str {
+        "nlp_http"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        if let Some(health_client) = &self.health_client {
+            health_client.health().await
+        } else {
+            self.client.health().await
+        }
+    }
+}
\ No newline at end of file
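For orientation, a hedged wiring sketch showing how the new HTTP NLP client could plug into `GenerationClient`. It assumes both types are re-exported from `crate::clients` (as the imports in generation.rs suggest) and that a `ServiceConfig` for the caikit NLP service was loaded elsewhere:

```rust
use hyper::HeaderMap;

use crate::clients::{GenerationClient, NlpClientHttp};
use crate::config::ServiceConfig;

// Route generation through the HTTP NLP provider and tokenize a string.
async fn tokenize_over_http(service: &ServiceConfig) -> Result<(), crate::clients::Error> {
    let nlp_http = NlpClientHttp::new(service, None).await;
    let generation = GenerationClient::nlp_http(nlp_http);
    let (count, tokens) = generation
        .tokenize("my-model".into(), "Hello world".into(), HeaderMap::new())
        .await?;
    println!("{count} tokens: {tokens:?}");
    Ok(())
}
```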
diff --git a/src/clients/openai.rs b/src/clients/openai.rs
new file mode 100644
index 00000000..a626d111
--- /dev/null
+++ b/src/clients/openai.rs
@@ -0,0 +1,643 @@
+/*
+ Copyright FMS Guardrails Orchestrator Authors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+*/
+
+use std::{collections::HashMap, convert::Infallible};
+
+use async_trait::async_trait;
+use axum::response::sse;
+use futures::StreamExt;
+use hyper::{HeaderMap, StatusCode};
+use reqwest_eventsource::{Event, RequestBuilderExt};
+use serde::{Deserialize, Serialize};
+use tokio::sync::mpsc;
+use tracing::{info, instrument};
+
+use super::{create_http_client, Client, Error, HttpClient};
+use crate::{
+    config::ServiceConfig, health::HealthCheckResult, tracing_utils::with_traceparent_header,
+};
+
+const DEFAULT_PORT: u16 = 8080;
+
+#[cfg_attr(test, faux::create)]
+#[derive(Clone)]
+pub struct OpenAiClient {
+    client: HttpClient,
+    health_client: Option<HttpClient>,
+}
+
+#[cfg_attr(test, faux::methods)]
+impl OpenAiClient {
+    pub async fn new(config: &ServiceConfig, health_config: Option<&ServiceConfig>) -> Self {
+        let client = create_http_client(DEFAULT_PORT, config).await;
+        let health_client = if let Some(health_config) = health_config {
+            Some(create_http_client(DEFAULT_PORT, health_config).await)
+        } else {
+            None
+        };
+        Self {
+            client,
+            health_client,
+        }
+    }
+
+    #[instrument(skip_all, fields(request.model))]
+    pub async fn chat_completions(
+        &self,
+        request: ChatCompletionsRequest,
+        headers: HeaderMap,
+    ) -> Result<ChatCompletionsResponse, Error> {
+        let url = self.client.base_url().join("/v1/chat/completions").unwrap();
+        let headers = with_traceparent_header(headers);
+        let stream = request.stream.unwrap_or_default();
+        info!(?url, ?headers, ?request, "sending client request");
+        if stream {
+            let (tx, rx) = mpsc::channel(32);
+            let mut event_stream = self
+                .client
+                .post(url)
+                .headers(headers)
+                .json(&request)
+                .eventsource()
+                .unwrap();
+            // Spawn task to forward events to receiver
+            tokio::spawn(async move {
+                while let Some(result) = event_stream.next().await {
+                    match result {
+                        Ok(event) => {
+                            if let Event::Message(message) = event {
+                                let event = sse::Event::default().data(message.data);
+                                let _ = tx.send(Ok(event)).await;
+                            }
+                        }
+                        Err(reqwest_eventsource::Error::StreamEnded) => break,
+                        Err(error) => {
+                            // We received an error from the event stream, send an error event
+                            let event =
+                                sse::Event::default().event("error").data(error.to_string());
+                            let _ = tx.send(Ok(event)).await;
+                        }
+                    }
+                }
+            });
+            Ok(ChatCompletionsResponse::Streaming(rx))
+        } else {
+            let response = self
+                .client
+                .post(url)
+                .headers(headers)
+                .json(&request)
+                .send()
+                .await?;
+            match response.status() {
+                StatusCode::OK => Ok(response.json::<ChatCompletion>().await?.into()),
+                _ => {
+                    let code = response.status();
+                    let message = if let Ok(response) = response.json::<OpenAiError>().await {
+                        response.message
+                    } else {
+                        "unknown error occurred".into()
+                    };
+                    Err(Error::Http { code, message })
+                }
+            }
+        }
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for OpenAiClient {
+    fn name(&self) -> &str {
+        "openai"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        if let Some(health_client) = &self.health_client {
+            health_client.health().await
+        } else {
+            self.client.health().await
+        }
+    }
+}
+
+#[derive(Debug)]
+pub enum ChatCompletionsResponse {
+    Unary(ChatCompletion),
+    Streaming(mpsc::Receiver<Result<sse::Event, Infallible>>),
+}
+
+impl From<ChatCompletion> for ChatCompletionsResponse {
+    fn from(value: ChatCompletion) -> Self {
+        Self::Unary(value)
+    }
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionsRequest {
+    /// A list of messages comprising the conversation so far.
+    pub messages: Vec<Message>,
+    /// ID of the model to use.
+    pub model: String,
+    /// Whether or not to store the output of this chat completion request.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub store: Option<bool>,
+    /// Developer-defined tags and values.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<serde_json::Value>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
+    /// Modify the likelihood of specified tokens appearing in the completion.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logit_bias: Option<HashMap<String, f32>>,
+    /// Whether to return log probabilities of the output tokens or not.
+    /// If true, returns the log probabilities of each output token returned in the content of message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<bool>,
+    /// An integer between 0 and 20 specifying the number of most likely tokens to return
+    /// at each token position, each with an associated log probability.
+    /// logprobs must be set to true if this parameter is used.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<u32>,
+    /// The maximum number of tokens that can be generated in the chat completion. (DEPRECATED)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>,
+    /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_completion_tokens: Option<u32>,
+    /// How many chat completion choices to generate for each input message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub n: Option<u32>,
+    /// Positive values penalize new tokens based on whether they appear in the text so far,
+    /// increasing the model's likelihood to talk about new topics.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
+    /// An object specifying the format that the model must output.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub response_format: Option<ResponseFormat>,
+    /// If specified, our system will make a best effort to sample deterministically,
+    /// such that repeated requests with the same seed and parameters should return the same result.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub seed: Option<u64>,
+    /// Specifies the latency tier to use for processing the request.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<String>,
+    /// Up to 4 sequences where the API will stop generating further tokens.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<StopTokens>,
+    /// If set, partial message deltas will be sent, like in ChatGPT.
+    /// Tokens will be sent as data-only server-sent events as they become available,
+    /// with the stream terminated by a data: [DONE] message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
+    /// Options for streaming response. Only set this when you set stream: true.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream_options: Option<StreamOptions>,
+    /// What sampling temperature to use, between 0 and 2.
+    /// Higher values like 0.8 will make the output more random,
+    /// while lower values like 0.2 will make it more focused and deterministic.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    /// An alternative to sampling with temperature, called nucleus sampling,
+    /// where the model considers the results of the tokens with top_p probability mass.
+    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+    /// A list of tools the model may call.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<Tool>>,
+    /// Controls which (if any) tool is called by the model.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    /// Whether to enable parallel function calling during tool use.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
+    /// A unique identifier representing your end-user.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user: Option<String>,
+
+    // Additional vllm params
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub best_of: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub use_beam_search: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_k: Option<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub min_p: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub repetition_penalty: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub length_penalty: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub early_stopping: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ignore_eos: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub min_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop_token_ids: Option<Vec<u64>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub skip_special_tokens: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub spaces_between_special_tokens: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponseFormat {
+    /// The type of response format being defined.
+    #[serde(rename = "type")]
+    pub r#type: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub json_schema: Option<JsonSchema>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JsonSchema {
+    /// The name of the response format.
+    pub name: String,
+    /// A description of what the response format is for, used by the model to determine how to respond in the format.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    /// The schema for the response format, described as a JSON Schema object.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub schema: Option<serde_json::Value>,
+    /// Whether to enable strict schema adherence when generating the output.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub strict: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Tool {
+    /// The type of the tool.
+    #[serde(rename = "type")]
+    pub r#type: String,
+    pub function: ToolFunction,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolFunction {
+    /// The name of the function to be called.
+    pub name: String,
+    /// A description of what the function does, used by the model to choose when and how to call the function.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    /// The parameters the functions accepts, described as a JSON Schema object.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parameters: Option<serde_json::Value>,
+    /// Whether to enable strict schema adherence when generating the function call.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub strict: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    /// `none` means the model will not call any tool and instead generates a message.
+    /// `auto` means the model can pick between generating a message or calling one or more tools.
+    /// `required` means the model must call one or more tools.
+    String(String),
+    /// Specifies a tool the model should use. Use to force the model to call a specific function.
+    Object(ToolChoiceObject),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolChoiceObject {
+    /// The type of the tool.
+    #[serde(rename = "type")]
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StreamOptions {
+    /// If set, an additional chunk will be streamed before the data: [DONE] message.
+    /// The usage field on this chunk shows the token usage statistics for the entire
+    /// request, and the choices field will always be an empty array. All other chunks
+    /// will also include a usage field, but with a null value.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub include_usage: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JsonSchemaObject {
+    pub id: String,
+    pub schema: String,
+    pub title: String,
+    pub description: Option<String>,
+    #[serde(rename = "type")]
+    pub r#type: String,
+    pub properties: Option<HashMap<String, serde_json::Value>>,
+    pub required: Option<Vec<String>>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct Message {
+    /// The role of the messages author.
+    pub role: String,
+    /// The contents of the message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<Content>,
+    /// An optional name for the participant.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    /// The refusal message by the assistant. (assistant message only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub refusal: Option<String>,
+    /// The tool calls generated by the model, such as function calls. (assistant message only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+    /// Tool call that this message is responding to. (tool message only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum Content {
+    /// The text contents of the message.
+    Text(String),
+    /// Array of content parts.
+    Array(Vec<ContentPart>),
+}
+
+impl From<String> for Content {
+    fn from(value: String) -> Self {
+        Content::Text(value)
+    }
+}
+
+impl From<&str> for Content {
+    fn from(value: &str) -> Self {
+        Content::Text(value.to_string())
+    }
+}
+
+impl From<Vec<ContentPart>> for Content {
+    fn from(value: Vec<ContentPart>) -> Self {
+        Content::Array(value)
+    }
+}
+
+impl From<String> for ContentPart {
+    fn from(value: String) -> Self {
+        ContentPart {
+            r#type: ContentType::Text,
+            text: Some(value),
+            image_url: None,
+            refusal: None,
+        }
+    }
+}
+
+impl From<Vec<String>> for Content {
+    fn from(value: Vec<String>) -> Self {
+        Content::Array(value.into_iter().map(|v| v.into()).collect())
+    }
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub enum ContentType {
+    #[serde(rename = "text")]
+    #[default]
+    Text,
+    #[serde(rename = "image_url")]
+    ImageUrl,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct ContentPart {
+    /// The type of the content part.
+    #[serde(rename = "type")]
+    pub r#type: ContentType,
+    /// Text content
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub text: Option<String>,
+    /// Image content
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub image_url: Option<ImageUrl>,
+    /// The refusal message generated by the model. (assistant message only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub refusal: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ImageUrl {
+    /// Either a URL of the image or the base64 encoded image data.
+    pub url: String,
+    /// Specifies the detail level of the image.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub detail: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCall {
+    /// The ID of the tool call.
+    pub id: String,
+    /// The type of the tool.
+    #[serde(rename = "type")]
+    pub r#type: String,
+    /// The function that the model called.
+    pub function: Function,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Function {
+    /// The name of the function to call.
+    pub name: String,
+    /// The arguments to call the function with, as generated by the model in JSON format.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub arguments: Option<HashMap<String, serde_json::Value>>,
+}
+
+/// Represents a chat completion response returned by model, based on the provided input.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletion {
+    /// A unique identifier for the chat completion.
+    pub id: String,
+    /// A list of chat completion choices. Can be more than one if n is greater than 1.
+    pub choices: Vec<ChatCompletionChoice>,
+    /// The Unix timestamp (in seconds) of when the chat completion was created.
+    pub created: i64,
+    /// The model used for the chat completion.
+    pub model: String,
+    /// The service tier used for processing the request.
+    /// This field is only included if the `service_tier` parameter is specified in the request.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<String>,
+    /// This fingerprint represents the backend configuration that the model runs with.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+    /// The object type, which is always `chat.completion`.
+    pub object: String,
+    /// Usage statistics for the completion request.
+    pub usage: Usage,
+}
+
+/// A chat completion choice.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionChoice {
+    /// The reason the model stopped generating tokens.
+    pub finish_reason: String,
+    /// The index of the choice in the list of choices.
+    pub index: usize,
+    /// A chat completion message generated by the model.
+    pub message: ChatCompletionMessage,
+    /// Log probability information for the choice.
+    pub logprobs: Option<ChatCompletionLogprobs>,
+}
+
+/// A chat completion message generated by the model.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionMessage {
+    /// The contents of the message.
+    pub content: Option<String>,
+    /// The refusal message generated by the model.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub refusal: Option<String>,
+    pub tool_calls: Vec<ToolCall>,
+    /// The role of the author of this message.
+    pub role: String,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ChatCompletionLogprobs {
+    /// A list of message content tokens with log probability information.
+    pub content: Option<Vec<ChatCompletionLogprob>>,
+    /// A list of message refusal tokens with log probability information.
+    pub refusal: Option<Vec<ChatCompletionLogprob>>,
+}
+
+/// Log probability information for a choice.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionLogprob {
+    /// The token.
+    pub token: String,
+    /// The log probability of this token.
+    pub logprob: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub bytes: Option<Vec<u8>>,
+    /// List of the most likely tokens and their log probability, at this token position.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<Vec<ChatCompletionTopLogprob>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionTopLogprob {
+    /// The token.
+    pub token: String,
+    /// The log probability of this token.
+    pub logprob: f32,
+}
+
+/// Represents a streamed chunk of a chat completion response returned by model, based on the provided input.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionChunk {
+    /// A unique identifier for the chat completion. Each chunk has the same ID.
+    pub id: String,
+    /// A list of chat completion choices.
+    pub choices: Vec<ChatCompletionChunkChoice>,
+    /// The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp.
+    pub created: i64,
+    /// The model to generate the completion.
+    pub model: String,
+    /// The service tier used for processing the request.
+    /// This field is only included if the service_tier parameter is specified in the request.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<String>,
+    /// This fingerprint represents the backend configuration that the model runs with.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+    /// The object type, which is always `chat.completion.chunk`.
+    pub object: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option<Usage>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionChunkChoice {
+    /// A chat completion delta generated by streamed model responses.
+    pub delta: ChatCompletionDelta,
+    /// Log probability information for the choice.
+    pub logprobs: Option<ChatCompletionLogprobs>,
+    /// The reason the model stopped generating tokens.
+    pub finish_reason: Option<String>,
+    /// The index of the choice in the list of choices.
+    pub index: u32,
+}
+
+/// A chat completion delta generated by streamed model responses.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionDelta {
+    /// The contents of the message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    /// The refusal message generated by the model.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub refusal: Option<String>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub tool_calls: Vec<ToolCall>,
+    /// The role of the author of this message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
+}
+
+/// Usage statistics for a completion.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Usage {
+    /// Number of tokens in the generated completion.
+    pub completion_tokens: u32,
+    /// Number of tokens in the prompt.
+    pub prompt_tokens: u32,
+    /// Total number of tokens used in the request (prompt + completion).
+    pub total_tokens: u32,
+    /// Breakdown of tokens used in a completion.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub completion_token_details: Option<CompletionTokenDetails>,
+    /// Breakdown of tokens used in the prompt.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_token_details: Option<PromptTokenDetails>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CompletionTokenDetails {
+    pub audio_tokens: u32,
+    pub reasoning_tokens: u32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PromptTokenDetails {
+    pub audio_tokens: u32,
+    pub cached_tokens: u32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum StopTokens {
+    Array(Vec<String>),
+    String(String),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OpenAiError {
+    pub object: Option<String>,
+    pub message: String,
+    #[serde(rename = "type")]
+    pub r#type: Option<String>,
+    pub param: Option<String>,
+    pub code: u16,
+}
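A hedged construction example for the request type above (field values are illustrative). It leans on the `From<&str> for Content` conversion and the `Default` derives, so unset options are skipped during serialization:

```rust
// Hypothetical helper building a minimal chat request.
fn example_request() -> ChatCompletionsRequest {
    ChatCompletionsRequest {
        model: "my-model".into(),
        messages: vec![
            Message {
                role: "system".into(),
                content: Some("You are a helpful assistant.".into()),
                ..Default::default()
            },
            Message {
                role: "user".into(),
                content: Some("Hello!".into()),
                ..Default::default()
            },
        ],
        temperature: Some(0.2),
        stream: Some(false),
        // Every other option stays None and is omitted from the JSON body.
        ..Default::default()
    }
}
```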
diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs
index bc7412de..a11976b1 100644
--- a/src/clients/tgis.rs
+++ b/src/clients/tgis.rs
@@ -14,129 +14,125 @@
 limitations under the License.
 */
 
-use std::collections::HashMap;
+use async_trait::async_trait;
 use axum::http::HeaderMap;
 use futures::{StreamExt, TryStreamExt};
 use ginepro::LoadBalancedChannel;
 use tonic::Code;
+use tracing::{info, instrument};
 
-use super::{create_grpc_clients, BoxStream, ClientCode, Error};
+use super::{
+    create_grpc_client, errors::grpc_to_http_code, grpc_request_with_headers, BoxStream, Client,
+    Error,
+};
 use crate::{
-    clients::COMMON_ROUTER_KEY,
     config::ServiceConfig,
-    health::{HealthCheckResult, HealthProbe, HealthStatus},
+    health::{HealthCheckResult, HealthStatus},
     pb::fmaas::{
         generation_service_client::GenerationServiceClient, BatchedGenerationRequest,
         BatchedGenerationResponse, BatchedTokenizeRequest, BatchedTokenizeResponse,
         GenerationResponse, ModelInfoRequest, ModelInfoResponse, SingleGenerationRequest,
     },
+    tracing_utils::trace_context_from_grpc_response,
 };
 
+const DEFAULT_PORT: u16 = 8033;
+
-#[cfg_attr(test, faux::create, derive(Default))]
+#[cfg_attr(test, faux::create)]
 #[derive(Clone)]
 pub struct TgisClient {
-    clients: HashMap<String, GenerationServiceClient<LoadBalancedChannel>>,
-}
-
-#[cfg_attr(test, faux::methods)]
-impl HealthProbe for TgisClient {
-    async fn health(&self) -> Result<HashMap<String, HealthCheckResult>, Error> {
-        let mut results = HashMap::with_capacity(self.clients.len());
-        for (model_id, mut client) in self.clients.clone() {
-            let response = client
-                .model_info(ModelInfoRequest {
-                    model_id: "".into(),
-                })
-                .await;
-            let code = match response {
-                Ok(_) => Code::Ok,
-                Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => {
-                    Code::Ok
-                }
-                Err(status) => status.code(),
-            };
-            let health_status = if matches!(code, Code::Ok) {
-                HealthStatus::Healthy
-            } else {
-                HealthStatus::Unhealthy
-            };
-            results.insert(
-                model_id,
-                HealthCheckResult {
-                    health_status,
-                    response_code: ClientCode::Grpc(code),
-                    reason: None,
-                },
-            );
-        }
-        Ok(results)
-    }
+    client: GenerationServiceClient<LoadBalancedChannel>,
 }
 
 #[cfg_attr(test, faux::methods)]
 impl TgisClient {
-    pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
-        let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await;
-        Self { clients }
-    }
-
-    fn client(
-        &self,
-        _model_id: &str,
-    ) -> Result<GenerationServiceClient<LoadBalancedChannel>, Error> {
-        // NOTE: We currently forward requests to the common-router, so we use a single client.
-        let model_id = COMMON_ROUTER_KEY;
-        Ok(self
-            .clients
-            .get(model_id)
-            .ok_or_else(|| Error::ModelNotFound {
-                model_id: model_id.to_string(),
-            })?
-            .clone())
+    pub async fn new(config: &ServiceConfig) -> Self {
+        let client = create_grpc_client(DEFAULT_PORT, config, GenerationServiceClient::new).await;
+        Self { client }
     }
 
+    #[instrument(skip_all, fields(model_id = request.model_id))]
     pub async fn generate(
         &self,
         request: BatchedGenerationRequest,
-        _headers: HeaderMap,
+        headers: HeaderMap,
     ) -> Result<BatchedGenerationResponse, Error> {
-        let model_id = request.model_id.as_str();
-        Ok(self.client(model_id)?.generate(request).await?.into_inner())
+        let request = grpc_request_with_headers(request, headers);
+        info!(?request, "sending request to TGIS gRPC service");
+        let mut client = self.client.clone();
+        Ok(client.generate(request).await?.into_inner())
     }
 
+    #[instrument(skip_all, fields(model_id = request.model_id))]
     pub async fn generate_stream(
         &self,
         request: SingleGenerationRequest,
-        _headers: HeaderMap,
+        headers: HeaderMap,
     ) -> Result<BoxStream<Result<GenerationResponse, Error>>, Error> {
-        let model_id = request.model_id.as_str();
-        let response_stream = self
-            .client(model_id)?
-            .generate_stream(request)
-            .await?
-            .into_inner()
-            .map_err(Into::into)
-            .boxed();
-        Ok(response_stream)
+        let request = grpc_request_with_headers(request, headers);
+        info!(?request, "sending request to TGIS gRPC service");
+        let mut client = self.client.clone();
+        let response = client.generate_stream(request).await?;
+        trace_context_from_grpc_response(&response);
+        Ok(response.into_inner().map_err(Into::into).boxed())
     }
 
+    #[instrument(skip_all, fields(model_id = request.model_id))]
     pub async fn tokenize(
         &self,
         request: BatchedTokenizeRequest,
-        _headers: HeaderMap,
+        headers: HeaderMap,
     ) -> Result<BatchedTokenizeResponse, Error> {
-        let model_id = request.model_id.as_str();
-        Ok(self.client(model_id)?.tokenize(request).await?.into_inner())
+        info!(?request, "sending request to TGIS gRPC service");
+        let mut client = self.client.clone();
+        let request = grpc_request_with_headers(request, headers);
+        let response = client.tokenize(request).await?;
+        trace_context_from_grpc_response(&response);
+        Ok(response.into_inner())
    }
 
     pub async fn model_info(&self, request: ModelInfoRequest) -> Result<ModelInfoResponse, Error> {
-        let model_id = request.model_id.as_str();
-        Ok(self
-            .client(model_id)?
-            .model_info(request)
-            .await?
-            .into_inner())
+        info!(?request, "sending request to TGIS gRPC service");
+        let request = grpc_request_with_headers(request, HeaderMap::new());
+        let mut client = self.client.clone();
+        let response = client.model_info(request).await?;
+        trace_context_from_grpc_response(&response);
+        Ok(response.into_inner())
+    }
+}
+
+#[cfg_attr(test, faux::methods)]
+#[async_trait]
+impl Client for TgisClient {
+    fn name(&self) -> &str {
+        "tgis"
+    }
+
+    async fn health(&self) -> HealthCheckResult {
+        let mut client = self.client.clone();
+        let response = client
+            .model_info(ModelInfoRequest {
+                model_id: "".into(),
+            })
+            .await;
+        let code = match response {
+            Ok(_) => Code::Ok,
+            Err(status) if matches!(status.code(), Code::InvalidArgument | Code::NotFound) => {
+                Code::Ok
+            }
+            Err(status) => status.code(),
+        };
+        let status = if matches!(code, Code::Ok) {
+            HealthStatus::Healthy
+        } else {
+            HealthStatus::Unhealthy
+        };
+        HealthCheckResult {
+            status,
+            code: grpc_to_http_code(code),
+            reason: None,
+        }
+    }
+}
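Reviewer note: `generate`, `generate_stream`, and `tokenize` now forward caller headers through `grpc_request_with_headers`, a helper in the shared client module that this diff does not show. A minimal sketch of the shape such a helper could take (an assumption, not the PR's code): wrap the message in a `tonic::Request` and copy ASCII-representable HTTP headers into the gRPC metadata so trace context propagates downstream.

```rust
use axum::http::HeaderMap;
use tonic::metadata::{Ascii, MetadataKey, MetadataValue};

// Assumed shape only: copy HTTP headers into tonic metadata.
fn grpc_request_with_headers<T>(message: T, headers: HeaderMap) -> tonic::Request<T> {
    let mut request = tonic::Request::new(message);
    for (name, value) in headers.iter() {
        // Non-ASCII names/values cannot be represented as gRPC metadata; skip them.
        if let Ok(value_str) = value.to_str() {
            if let (Ok(key), Ok(val)) = (
                name.as_str().parse::<MetadataKey<Ascii>>(),
                value_str.parse::<MetadataValue<Ascii>>(),
            ) {
                request.metadata_mut().insert(key, val);
            }
        }
    }
    request
}
```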
diff --git a/src/config.rs b/src/config.rs
index 1ebc7879..25e552a8 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -23,7 +23,7 @@ use std::{
 use serde::Deserialize;
 use tracing::{debug, error, info, warn};

-use crate::clients::chunker::DEFAULT_MODEL_ID;
+use crate::clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname};

 // Placeholder to add default allowed headers
 const DEFAULT_ALLOWED_HEADERS: &[&str] = &[];
@@ -47,6 +47,10 @@ pub enum Error {
         detector_id: String,
         chunker_id: String,
     },
+    #[error("invalid generation provider: {0}")]
+    InvalidGenerationProvider(String),
+    #[error("invalid hostname: {0}")]
+    InvalidHostname(String),
 }

 /// Configuration for service needed for
@@ -84,14 +88,17 @@ pub struct TlsConfig {
 /// Generation service provider
 #[cfg_attr(test, derive(Default))]
 #[derive(Clone, Copy, Debug, Deserialize)]
-#[serde(rename_all = "lowercase")]
 pub enum GenerationProvider {
     #[cfg_attr(test, default)]
+    #[serde(rename = "tgis")]
     Tgis,
+    #[serde(rename = "nlp")]
     Nlp,
+    #[serde(rename = "nlp-http")]
+    NlpHttp,
 }

-/// Generate service configuration
+/// Generation service configuration
 #[cfg_attr(test, derive(Default))]
 #[derive(Clone, Debug, Deserialize)]
 pub struct GenerationConfig {
@@ -101,6 +108,16 @@ pub struct GenerationConfig {
     pub service: ServiceConfig,
 }

+/// Chat generation service configuration
+#[cfg_attr(test, derive(Default))]
+#[derive(Clone, Debug, Deserialize)]
+pub struct ChatGenerationConfig {
+    /// Generation service connection information
+    pub service: ServiceConfig,
+    /// Generation health service connection information
+    pub health_service: Option<ServiceConfig>,
+}
+
 /// Chunker parser type
 #[cfg_attr(test, derive(Default))]
 #[derive(Clone, Copy, Debug, Deserialize)]
@@ -127,10 +144,26 @@ pub struct ChunkerConfig {
 pub struct DetectorConfig {
     /// Detector service connection information
     pub service: ServiceConfig,
+    /// Detector health service connection information
+    pub health_service: Option<ServiceConfig>,
     /// ID of chunker that this detector will use
     pub chunker_id: String,
     /// Default threshold with which to filter detector results by score
     pub default_threshold: f64,
+    /// Type of detection this detector performs
+    #[serde(rename = "type")]
+    pub r#type: DetectorType,
+}
+
+#[derive(Clone, Debug, Default, Deserialize)]
+#[serde(rename_all = "snake_case")]
+#[non_exhaustive]
+pub enum DetectorType {
+    #[default]
+    TextContents,
+    TextGeneration,
+    TextChat,
+    TextContextDoc,
 }

 /// Overall orchestrator server configuration
@@ -139,6 +172,8 @@ pub struct OrchestratorConfig {
     /// Generation service and associated configuration, can be omitted if configuring for generation is not wanted
     pub generation: Option<GenerationConfig>,
+    /// Chat generation service and associated configuration, can be omitted if configuring for chat generation is not wanted
+    pub chat_generation: Option<ChatGenerationConfig>,
     /// Chunker services and associated configurations, if omitted the default value "whole_doc_chunker" is used
     pub chunkers: Option<HashMap<String, ChunkerConfig>>,
     /// Detector services and associated configurations
@@ -206,6 +241,10 @@ impl OrchestratorConfig {
         if let Some(generation) = &mut self.generation {
             apply_named_tls_config(&mut generation.service, tls_configs)?;
         }
+        // Chat generation
+        if let Some(chat_generation) = &mut self.chat_generation {
+            apply_named_tls_config(&mut chat_generation.service, tls_configs)?;
+        }
         // Chunkers
         if let Some(chunkers) = &mut self.chunkers {
             for chunker in chunkers.values_mut() {
@@ -221,25 +260,84 @@ impl OrchestratorConfig {
     }

     fn validate(&self) -> Result<(), Error> {
+        // Detectors are configured
         if self.detectors.is_empty() {
-            Err(Error::NoDetectorsConfigured)
-        } else {
-            for (detector_id, detector) in &self.detectors {
-                // Chunker is valid
-                let valid_chunker = detector.chunker_id == DEFAULT_MODEL_ID
-                    || self
-                        .chunkers
-                        .as_ref()
-                        .is_some_and(|chunkers| chunkers.contains_key(&detector.chunker_id));
-                if !valid_chunker {
-                    return Err(Error::DetectorChunkerNotFound {
-                        detector_id: detector_id.clone(),
-                        chunker_id: detector.chunker_id.clone(),
-                    });
+            return Err(Error::NoDetectorsConfigured);
+        }
+
+        // Apply validation rules
+        self.validate_generation_config()?;
+        self.validate_chat_generation_config()?;
+        self.validate_detector_configs()?;
+        self.validate_chunker_configs()?;
+
+        Ok(())
+    }
+
+    /// Validates generation config.
+    fn validate_generation_config(&self) -> Result<(), Error> {
+        if let Some(generation) = &self.generation {
+            // Hostname is valid
+            if !is_valid_hostname(&generation.service.hostname) {
+                return Err(Error::InvalidHostname(
+                    "`generation` has an invalid hostname".into(),
+                ));
+            }
+        }
+        Ok(())
+    }
+
+    /// Validates chat generation config.
+    fn validate_chat_generation_config(&self) -> Result<(), Error> {
+        if let Some(chat_generation) = &self.chat_generation {
+            // Hostname is valid
+            if !is_valid_hostname(&chat_generation.service.hostname) {
+                return Err(Error::InvalidHostname(
+                    "`chat_generation` has an invalid hostname".into(),
+                ));
+            }
+        }
+        Ok(())
+    }
+
+    /// Validates detector configs.
+    fn validate_detector_configs(&self) -> Result<(), Error> {
+        for (detector_id, detector) in &self.detectors {
+            // Hostname is valid
+            if !is_valid_hostname(&detector.service.hostname) {
+                return Err(Error::InvalidHostname(format!(
+                    "detector `{detector_id}` has an invalid hostname"
+                )));
+            }
+            // Chunker is valid
+            let valid_chunker = detector.chunker_id == DEFAULT_CHUNKER_ID
+                || self
+                    .chunkers
+                    .as_ref()
+                    .is_some_and(|chunkers| chunkers.contains_key(&detector.chunker_id));
+            if !valid_chunker {
+                return Err(Error::DetectorChunkerNotFound {
+                    detector_id: detector_id.clone(),
+                    chunker_id: detector.chunker_id.clone(),
+                });
+            }
+        }
+        Ok(())
+    }
+
+    /// Validates chunker configs.
+    fn validate_chunker_configs(&self) -> Result<(), Error> {
+        if let Some(chunkers) = &self.chunkers {
+            for (chunker_id, chunker) in chunkers {
+                // Hostname is valid
+                if !is_valid_hostname(&chunker.service.hostname) {
+                    return Err(Error::InvalidHostname(format!(
+                        "chunker `{chunker_id}` has an invalid hostname"
+                    )));
                 }
             }
-            Ok(())
         }
+        Ok(())
     }

     /// Get ID of chunker associated with a particular detector
@@ -301,6 +399,7 @@ chunkers:
             port: 9000
 detectors:
     hap-en:
+        type: text_contents
         service:
             hostname: localhost
             port: 9000
@@ -341,6 +440,7 @@ chunkers:
             port: 9000
 detectors:
     hap:
+        type: text_contents
         service:
             hostname: localhost
             port: 9000
@@ -393,6 +493,7 @@ chunkers:
             port: 9000
 detectors:
     hap:
+        type: text_contents
         service:
             hostname: localhost
             port: 9000
@@ -487,6 +588,7 @@ chunkers:
             port: 9000
 detectors:
     hap:
+        type: text_contents
         service:
             hostname: localhost
             port: 9000
@@ -527,6 +629,7 @@ chunkers:
             port: 9000
 detectors:
     hap:
+        type: text_contents
         service:
             hostname: localhost
             port: 9000
@@ -543,10 +646,9 @@ tls:
         config
             .apply_named_tls_configs()
            .expect("Apply named TLS configs should have succeeded");
-        let error = config
+        config
            .validate()
            .expect_err("Config should not have been validated");
-        assert!(matches!(error, Error::DetectorChunkerNotFound { .. }))
     }

     #[test]
@@ -565,12 +667,13 @@ chunkers:
             port: 9000
 detectors:
     hap:
+        type: text_contents
         service:
             hostname: localhost
             port: 9000
             tls: detector
         chunker_id: sentence-fr
-       default_threshold: 0.5
+        default_threshold: 0.5
 "#;
         let config: OrchestratorConfig = serde_yml::from_str(s).unwrap();
         assert!(config.passthrough_headers.is_empty());
@@ -592,12 +695,13 @@ chunkers:
             port: 9000
 detectors:
     hap:
+        type: text_contents
         service:
             hostname: localhost
             port: 9000
             tls: detector
         chunker_id: sentence-fr
-       default_threshold: 0.5
+        default_threshold: 0.5

 passthrough_headers:
     - test-header
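Reviewer note: every detector entry in the test configs now carries `type: text_contents`, deserialized through the new `DetectorType` enum via `rename_all = "snake_case"`. A standalone illustration of that mapping (the enum is mirrored here for self-containment):

```rust
use serde::Deserialize;

// Mirror of the enum added above, to show the snake_case wire format.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
enum DetectorType {
    TextContents,
    TextGeneration,
    TextChat,
    TextContextDoc,
}

fn main() {
    // "text_contents" in YAML maps to DetectorType::TextContents.
    let t: DetectorType = serde_yml::from_str("text_contents").unwrap();
    assert_eq!(t, DetectorType::TextContents);
}
```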
diff --git a/src/health.rs b/src/health.rs
index af43d94e..8ee56dac 100644
--- a/src/health.rs
+++ b/src/health.rs
@@ -1,218 +1,128 @@
-use std::{collections::HashMap, fmt::Display, sync::Arc};
+use std::{collections::HashMap, fmt::Display};

-use axum::{
-    http::StatusCode,
-    response::{IntoResponse, Response},
-    Json,
-};
-use serde::{ser::SerializeStruct, Deserialize, Serialize};
-use tokio::sync::RwLock;
-use tonic::Code;
-use tracing::{error, warn};
+use axum::http::StatusCode;
+use serde::{Deserialize, Serialize};

 use crate::{
-    clients::{ClientCode, Error},
-    pb::grpc::health::v1::HealthCheckResponse,
+    clients::errors::grpc_to_http_code,
+    pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse},
 };

-/// A health check endpoint for a singular client.
-/// NOTE: Only implemented by HTTP clients, gRPC clients with health check support should use the generated `grpc::health::v1::health_client::HealthClient` service.
-pub trait HealthCheck {
-    /// Makes a request to the client service health check endpoint and turns result into a `HealthCheckResult`.
-    fn check(&self) -> impl std::future::Future<Output = HealthCheckResult> + Send;
-}
-
-/// A health probe for aggregated health check results of multiple client services.
-pub trait HealthProbe {
-    /// Makes a health check request to each client and returns a map of client service ids to health check results.
-    fn health(
-        &self,
-    ) -> impl std::future::Future<Output = Result<HashMap<String, HealthCheckResult>, Error>> + Send;
-}
-
 /// Health status determined for or returned by a client service.
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[serde(rename_all = "UPPERCASE")]
 pub enum HealthStatus {
-    /// The service is healthy and should be considered ready to serve requests.
-    #[serde(rename = "HEALTHY")]
+    /// The service status is healthy.
     Healthy,
-    /// The service is unhealthy and should be considered not ready to serve requests.
-    #[serde(rename = "UNHEALTHY")]
+    /// The service status is unhealthy.
     Unhealthy,
-    /// The health status of the service (and possibly the service itself) is unknown.
-    /// The health check response indicated the service's health is unknown or the health request failed in a way that could have been a misconfiguration,
-    /// meaning the actual service could still be healthy.
-    #[serde(rename = "UNKNOWN")]
+    /// The service status is unknown.
     Unknown,
 }

-/// An optional response body that can be interpreted from an HTTP health check response.
-/// This is a minimal contract that allows HTTP health requests to opt in to more detailed health check responses than just the status code.
-/// If the body omitted, the health check response is considered successful if the status code is `HTTP 200 OK`.
-#[derive(serde::Deserialize)]
-pub struct OptionalHealthCheckResponseBody {
-    /// `HEALTHY`, `UNHEALTHY`, or `UNKNOWN`. Although `HEALTHY` is already implied without a body.
-    pub health_status: HealthStatus,
-    /// Optional reason for the health check result status being `UNHEALTHY` or `UNKNOWN`.
-    /// May be omitted overall if the health check was successful.
-    #[serde(default)]
-    pub reason: Option<String>,
+impl Display for HealthStatus {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            HealthStatus::Healthy => write!(f, "HEALTHY"),
+            HealthStatus::Unhealthy => write!(f, "UNHEALTHY"),
+            HealthStatus::Unknown => write!(f, "UNKNOWN"),
+        }
+    }
 }

-/// Result of a health check request.
-#[derive(Debug, Clone)]
-pub struct HealthCheckResult {
-    /// Overall health status of client service.
-    /// `HEALTHY`, `UNHEALTHY`, or `UNKNOWN`.
-    pub health_status: HealthStatus,
-    /// Response code of the latest health check request.
-    /// This should be omitted on serialization if the health check was successful (when the response is `HTTP 200 OK` or `gRPC 0 OK`).
-    pub response_code: ClientCode,
-    /// Optional reason for the health check result status being `UNHEALTHY` or `UNKNOWN`.
-    /// May be omitted overall if the health check was successful.
-    pub reason: Option<String>,
+impl From<HealthCheckResponse> for HealthStatus {
+    fn from(value: HealthCheckResponse) -> Self {
+        match value.status() {
+            ServingStatus::Serving => Self::Healthy,
+            ServingStatus::NotServing => Self::Unhealthy,
+            ServingStatus::Unknown | ServingStatus::ServiceUnknown => Self::Unknown,
+        }
+    }
+}
+
+impl From<StatusCode> for HealthStatus {
+    fn from(code: StatusCode) -> Self {
+        match code.as_u16() {
+            200..=299 => Self::Healthy,
+            500..=599 => Self::Unhealthy,
+            _ => Self::Unknown,
+        }
+    }
 }

 /// A cache to hold the latest health check results for each client service.
 /// Orchestrator has a reference-counted mutex-protected instance of this cache.
 #[derive(Debug, Clone, Default, Serialize)]
-pub struct HealthCheckCache {
-    pub detectors: HashMap<String, HealthCheckResult>,
-    pub chunkers: HashMap<String, HealthCheckResult>,
-    pub generation: HashMap<String, HealthCheckResult>,
-}
-
-/// Response for the readiness probe endpoint that holds a serialized cache of health check results for each client service.
-#[derive(Debug, Clone, Serialize)]
-pub struct HealthProbeResponse {
-    pub services: HealthCheckCache,
-}
+pub struct HealthCheckCache(HashMap<String, HealthCheckResult>);

-/// Query param for triggering the client health check probe on the `/info` endpoint.
-#[derive(Debug, Clone, Deserialize)]
-pub struct HealthCheckProbeParams {
-    /// Whether to probe the client services' health checks or just return the cached health status.
-    #[serde(default)]
-    pub probe: bool,
-}
+impl HealthCheckCache {
+    pub fn new() -> Self {
+        Self(HashMap::new())
+    }

-impl HealthCheckResult {
-    pub fn reason_from_health_check_response(response: &HealthCheckResponse) -> Option<String> {
-        match response.status {
-            0 => Some("from gRPC health check serving status: UNKNOWN".to_string()),
-            1 => None,
-            2 => Some("from gRPC health check serving status: NOT_SERVING".to_string()),
-            3 => Some("from gRPC health check serving status: SERVICE_UNKNOWN".to_string()),
-            _ => {
-                error!(
-                    "Unexpected gRPC health check serving status: {}",
-                    response.status
-                );
-                Some(format!(
-                    "Unexpected gRPC health check serving status: {}",
-                    response.status
-                ))
-            }
-        }
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self(HashMap::with_capacity(capacity))
     }
-}

-impl HealthCheckCache {
-    pub fn is_initialized(&self) -> bool {
-        !self.detectors.is_empty() && !self.chunkers.is_empty() && !self.generation.is_empty()
+    /// Returns `true` if all services are healthy or unknown.
+    pub fn healthy(&self) -> bool {
+        !self
+            .0
+            .iter()
+            .any(|(_, value)| matches!(value.status, HealthStatus::Unhealthy))
     }
 }

-impl HealthProbeResponse {
-    pub async fn from_cache(cache: Arc<RwLock<HealthCheckCache>>) -> Self {
-        let services = cache.read().await.clone();
-        Self { services }
+impl std::ops::Deref for HealthCheckCache {
+    type Target = HashMap<String, HealthCheckResult>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
     }
 }

-impl Display for HealthStatus {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            HealthStatus::Healthy => write!(f, "HEALTHY"),
-            HealthStatus::Unhealthy => write!(f, "UNHEALTHY"),
-            HealthStatus::Unknown => write!(f, "UNKNOWN"),
-        }
+impl std::ops::DerefMut for HealthCheckCache {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
     }
 }

 impl Display for HealthCheckCache {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let mut services = vec![];
-        let mut detectors = vec![];
-        let mut chunkers = vec![];
-        let mut generation = vec![];
-        for (service, result) in &self.detectors {
-            detectors.push(format!("\t\t{}: {}", service, result));
-        }
-        for (service, result) in &self.chunkers {
-            chunkers.push(format!("\t\t{}: {}", service, result));
-        }
-        for (service, result) in &self.generation {
-            generation.push(format!("\t\t{}: {}", service, result));
-        }
-        if !self.detectors.is_empty() {
-            services.push(format!("\tdetectors: {{\n{}\t}}", detectors.join(",\n")));
-        }
-        if !self.chunkers.is_empty() {
-            services.push(format!("\tchunkers: {{\n{}\t}}", chunkers.join(",\n")));
-        }
-        if !self.generation.is_empty() {
-            services.push(format!("\tgeneration: {{\n{}\t}}", generation.join(",\n")));
-        }
-        write!(
-            f,
-            "configured client services: {{\n{}\n}}",
-            services.join(",\n")
-        )
+        write!(f, "{}", serde_json::to_string_pretty(self).unwrap())
     }
 }

-impl Display for HealthProbeResponse {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.services)
+impl HealthCheckResponse {
+    pub fn reason(&self) -> Option<String> {
+        let status = self.status();
+        match status {
+            ServingStatus::Serving => None,
+            _ => Some(status.as_str_name().to_string()),
+        }
+    }
+}

-impl Serialize for HealthCheckResult {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        match self.health_status {
-            HealthStatus::Healthy => self.health_status.serialize(serializer),
-            _ => match &self.reason {
-                Some(reason) => {
-                    let mut state = serializer.serialize_struct("HealthCheckResult", 3)?;
-                    state.serialize_field("health_status", &self.health_status)?;
-                    state.serialize_field("response_code", &self.response_code.to_string())?;
-                    state.serialize_field("reason", reason)?;
-                    state.end()
-                }
-                None => {
-                    let mut state = serializer.serialize_struct("HealthCheckResult", 2)?;
-                    state.serialize_field("health_status", &self.health_status)?;
-                    state.serialize_field("response_code", &self.response_code.to_string())?;
-                    state.end()
-                }
-            },
-        }
-    }
+/// Result of a health check request.
+#[derive(Debug, Clone, Serialize)]
+pub struct HealthCheckResult {
+    /// Overall health status of client service.
+    pub status: HealthStatus,
+    /// Response code of the latest health check request.
+    #[serde(
+        with = "http_serde::status_code",
+        skip_serializing_if = "StatusCode::is_success"
+    )]
+    pub code: StatusCode,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reason: Option<String>,
 }

 impl Display for HealthCheckResult {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match &self.reason {
-            Some(reason) => write!(
-                f,
-                "{} ({})\n\t\t\t{}",
-                self.health_status, self.response_code, reason
-            ),
-            None => write!(f, "{} ({})", self.health_status, self.response_code),
+            Some(reason) => write!(f, "{} ({})\n\t\t\t{}", self.status, self.code, reason),
+            None => write!(f, "{} ({})", self.status, self.code),
         }
     }
 }

@@ -223,63 +133,29 @@ impl From<Result<tonic::Response<HealthCheckResponse>, tonic::Status>> for HealthCheckResult {
             Ok(response) => {
                 let response = response.into_inner();
+                let reason = response.reason();
                 Self {
-                    health_status: response.into(),
-                    response_code: ClientCode::Grpc(Code::Ok),
-                    reason: Self::reason_from_health_check_response(&response),
+                    status: response.into(),
+                    code: StatusCode::OK,
+                    reason,
                 }
             }
             Err(status) => Self {
-                health_status: HealthStatus::Unknown,
-                response_code: ClientCode::Grpc(status.code()),
-                reason: Some(format!("gRPC health check failed: {}", status)),
+                status: HealthStatus::Unknown,
+                code: grpc_to_http_code(status.code()),
+                reason: Some(status.message().to_string()),
             },
         }
     }
 }

-impl From<HealthCheckResponse> for HealthStatus {
-    fn from(value: HealthCheckResponse) -> Self {
-        // NOTE: gRPC Health v1 status codes: 0 = UNKNOWN, 1 = SERVING, 2 = NOT_SERVING, 3 = SERVICE_UNKNOWN
-        match value.status {
-            1 => Self::Healthy,
-            2 => Self::Unhealthy,
-            _ => Self::Unknown,
-        }
-    }
-}
-
-impl From<StatusCode> for HealthStatus {
-    fn from(code: StatusCode) -> Self {
-        match code.as_u16() {
-            200 => Self::Healthy,
-            201..=299 => {
-                warn!(
-                    "Unexpected HTTP successful health check response status code: {}",
-                    code
-                );
-                Self::Healthy
-            }
-            503 => Self::Unhealthy,
-            500..=502 | 504..=599 => {
-                warn!(
-                    "Unexpected HTTP server error health check response status code: {}",
-                    code
-                );
-                Self::Unhealthy
-            }
-            _ => {
-                warn!(
-                    "Unexpected HTTP client error health check response status code: {}",
-                    code
-                );
-                Self::Unknown
-            }
-        }
-    }
-}
-
-impl IntoResponse for HealthProbeResponse {
-    fn into_response(self) -> Response {
-        (StatusCode::OK, Json(self)).into_response()
-    }
+/// An optional response body that can be interpreted from an HTTP health check response.
+/// This is a minimal contract that allows HTTP health requests to opt in to more detailed health check responses than just the status code.
+/// If the body is omitted, the health check response is considered successful if the status code is `HTTP 200 OK`.
+#[derive(Deserialize)]
+pub struct OptionalHealthCheckResponseBody {
+    /// `HEALTHY`, `UNHEALTHY`, or `UNKNOWN`. Although `HEALTHY` is already implied without a body.
+    pub status: HealthStatus,
+    /// Optional reason for the health check result status being `UNHEALTHY` or `UNKNOWN`.
+    /// May be omitted overall if the health check was successful.
+    #[serde(default)]
+    pub reason: Option<String>,
 }
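Reviewer note: the new `From<StatusCode>` impl replaces the old per-code warnings with three plain buckets: any 2xx is healthy, any 5xx is unhealthy, everything else is unknown. Restated as a standalone check with the same ranges:

```rust
use axum::http::StatusCode;

#[derive(Debug, PartialEq)]
enum HealthStatus { Healthy, Unhealthy, Unknown }

// Same bucketing as the impl above.
fn from_code(code: StatusCode) -> HealthStatus {
    match code.as_u16() {
        200..=299 => HealthStatus::Healthy,
        500..=599 => HealthStatus::Unhealthy,
        _ => HealthStatus::Unknown,
    }
}

fn main() {
    assert_eq!(from_code(StatusCode::NO_CONTENT), HealthStatus::Healthy); // 204
    assert_eq!(from_code(StatusCode::BAD_GATEWAY), HealthStatus::Unhealthy); // 502
    assert_eq!(from_code(StatusCode::TOO_MANY_REQUESTS), HealthStatus::Unknown); // 429
}
```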
diff --git a/src/lib.rs b/src/lib.rs
index bfff8695..efc18bc0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -15,8 +15,9 @@

 */

-#![allow(clippy::iter_kv_map, clippy::enum_variant_names)]
+#![allow(clippy::iter_kv_map, clippy::enum_variant_names, async_fn_in_trait)]

+pub mod args;
 mod clients;
 pub mod config;
 pub mod health;
@@ -24,3 +25,4 @@ mod models;
 pub mod orchestrator;
 mod pb;
 pub mod server;
+pub mod tracing_utils;
diff --git a/src/main.rs b/src/main.rs
index 21a5c5c0..0cc174c2 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -15,39 +15,12 @@

 */

-use std::{
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-    path::PathBuf,
-};
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};

 use clap::Parser;
-use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator, server};
-use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
-
-#[derive(Parser, Debug)]
-#[clap(author, version, about, long_about = None)]
-struct Args {
-    #[clap(default_value = "8033", long, env)]
-    http_port: u16,
-    #[clap(default_value = "8034", long, env)]
-    health_http_port: u16,
-    #[clap(long, env)]
-    json_output: bool,
-    #[clap(
-        default_value = "config/config.yaml",
-        long,
-        env = "ORCHESTRATOR_CONFIG"
-    )]
-    config_path: PathBuf,
-    #[clap(long, env)]
-    tls_cert_path: Option<PathBuf>,
-    #[clap(long, env)]
-    tls_key_path: Option<PathBuf>,
-    #[clap(long, env)]
-    tls_client_ca_cert_path: Option<PathBuf>,
-    #[clap(default_value = "false", long, env)]
-    start_up_health_check: bool,
-}
+use fms_guardrails_orchestr8::{
+    args::Args, config::OrchestratorConfig, orchestrator::Orchestrator, server, tracing_utils,
+};

 fn main() -> Result<(), anyhow::Error> {
     rustls::crypto::aws_lc_rs::default_provider()
@@ -62,14 +35,6 @@ fn main() -> Result<(), anyhow::Error> {
         panic!("tls: cannot provide client ca cert without keypair")
     }

-    let filter = EnvFilter::try_from_default_env()
-        .unwrap_or(EnvFilter::new("INFO"))
-        .add_directive("ginepro=info".parse().unwrap());
-    tracing_subscriber::registry()
-        .with(filter)
-        .with(tracing_subscriber::fmt::layer())
-        .init();
-
     let http_addr: SocketAddr =
         SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), args.http_port);
     let health_http_addr: SocketAddr =
@@ -81,6 +46,7 @@ fn main() -> Result<(), anyhow::Error> {
         .build()
         .unwrap()
         .block_on(async {
+            let trace_shutdown = tracing_utils::init_tracing(args.clone().into())?;
             let config = OrchestratorConfig::load(args.config_path).await?;
             let orchestrator = Orchestrator::new(config, args.start_up_health_check).await?;

@@ -93,6 +59,6 @@ fn main() -> Result<(), anyhow::Error> {
             orchestrator,
         )
         .await?;
-        Ok(())
+        Ok(trace_shutdown()?)
     })
 }
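Reviewer note: `tracing_utils::init_tracing` is introduced outside this hunk; `main` treats it as an initializer that hands back a shutdown hook (`trace_shutdown`). A minimal sketch of that shape, with the OTLP exporter wiring omitted (an assumption about the module's interface, not its actual code):

```rust
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};

// Assumed shape only: install a subscriber, return a shutdown hook for main().
fn init_tracing() -> Result<impl FnOnce() -> Result<(), anyhow::Error>, anyhow::Error> {
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("INFO"));
    tracing_subscriber::registry()
        .with(filter)
        .with(tracing_subscriber::fmt::layer())
        .try_init()?;
    Ok(|| {
        // The real hook would flush and shut down the OTLP pipeline here.
        Ok(())
    })
}
```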
diff --git a/src/models.rs b/src/models.rs
index 1eb764ce..68b790a0 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -22,10 +22,29 @@ use std::collections::HashMap;

 use serde::{Deserialize, Serialize};

 use crate::{
-    clients::detector::{ContentAnalysisResponse, ContextType},
+    clients::{
+        self,
+        detector::{ContentAnalysisResponse, ContextType},
+        openai::{Content, ContentType},
+    },
+    health::HealthCheckCache,
     pb,
 };

+pub const THRESHOLD_PARAM: &str = "threshold";
+
+#[derive(Clone, Debug, Serialize)]
+pub struct InfoResponse {
+    pub services: HealthCheckCache,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct InfoParams {
+    /// Whether to probe the client services' health checks or just return the latest health status.
+    #[serde(default)]
+    pub probe: bool,
+}
+
 /// Parameters relevant to each detector
 #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
 pub struct DetectorParams(HashMap<String, serde_json::Value>);

@@ -37,8 +56,8 @@ impl DetectorParams {
     }

     /// Threshold to filter detector results by score.
-    pub fn threshold(&self) -> Option<f64> {
-        self.0.get("threshold").and_then(|v| v.as_f64())
+    pub fn pop_threshold(&mut self) -> Option<f64> {
+        self.0.remove(THRESHOLD_PARAM).and_then(|v| v.as_f64())
     }
 }

@@ -926,6 +945,79 @@ pub struct ContextDocsResult {
     pub detections: Vec<DetectionResult>,
 }

+/// The request format expected in the /api/v2/text/detection/chat endpoint.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ChatDetectionHttpRequest {
+    /// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
+    pub detectors: HashMap<String, DetectorParams>,
+
+    // The list of messages to run detections on.
+    pub messages: Vec<clients::openai::Message>,
+}
+
+impl ChatDetectionHttpRequest {
+    /// Upfront validation of user request
+    pub fn validate(&self) -> Result<(), ValidationError> {
+        // Validate required parameters
+        if self.detectors.is_empty() {
+            return Err(ValidationError::Required("detectors".into()));
+        }
+        if self.messages.is_empty() {
+            return Err(ValidationError::Required("messages".into()));
+        }
+
+        Ok(())
+    }
+
+    /// Validates for the "/api/v1/text/chat" endpoint.
+    pub fn validate_for_text(&self) -> Result<(), ValidationError> {
+        self.validate()?;
+        self.validate_messages()?;
+        validate_detector_params(&self.detectors)?;
+
+        Ok(())
+    }
+
+    /// Validates if message contents are either a string or a content type of type "text"
+    fn validate_messages(&self) -> Result<(), ValidationError> {
+        for message in &self.messages {
+            match &message.content {
+                Some(content) => self.validate_content_type(content)?,
+                None => {
+                    return Err(ValidationError::Invalid(
+                        "Message content cannot be empty".into(),
+                    ))
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Validates if content type array contains only text messages
+    fn validate_content_type(&self, content: &Content) -> Result<(), ValidationError> {
+        match content {
+            Content::Array(content) => {
+                for content_part in content {
+                    if !matches!(content_part.r#type, ContentType::Text) {
+                        return Err(ValidationError::Invalid(
+                            "Only content of type text is allowed".into(),
+                        ));
+                    }
+                }
+                Ok(())
+            }
+            Content::Text(_) => Ok(()), // if message.content is a string, it is a valid message
+        }
+    }
+}
+
+/// The response format of the /api/v2/text/detection/chat endpoint
+#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct ChatDetectionResult {
+    /// Detection results
+    pub detections: Vec<DetectionResult>,
+}
+
 /// The request format expected in the /api/v2/text/detect/generated endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct DetectionOnGeneratedHttpRequest {
@@ -1182,10 +1274,12 @@ mod tests {
         {
             "threshold": 0.2
         }"#;
-        let value: DetectorParams = serde_json::from_str(value_json)?;
-        assert_eq!(value.threshold(), Some(0.2));
-        let value = DetectorParams::new();
-        assert_eq!(value.threshold(), None);
+        let mut value: DetectorParams = serde_json::from_str(value_json)?;
+        assert_eq!(value.pop_threshold(), Some(0.2));
+        assert!(!value.contains_key("threshold"));
+        let mut value = DetectorParams::new();
+        assert!(!value.contains_key("threshold"));
+        assert_eq!(value.pop_threshold(), None);
         Ok(())
     }
 }
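Reviewer note: the rename from `threshold` to `pop_threshold` is behavioural, not cosmetic: the threshold is removed from the parameter map, so it is no longer forwarded to detectors, which is exactly what the updated test asserts. The contract in isolation:

```rust
use std::collections::HashMap;

// Standalone restatement of the pop_threshold contract: read the threshold
// out of the params and remove it, so only the remaining params flow onward.
fn pop_threshold(params: &mut HashMap<String, serde_json::Value>) -> Option<f64> {
    params.remove("threshold").and_then(|v| v.as_f64())
}

fn main() {
    let mut params: HashMap<String, serde_json::Value> =
        serde_json::from_str(r#"{ "threshold": 0.2, "model": "hap" }"#).unwrap();
    assert_eq!(pop_threshold(&mut params), Some(0.2));
    assert!(!params.contains_key("threshold")); // only "model" remains
}
```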
diff --git a/src/orchestrator.rs b/src/orchestrator.rs
index 607ae96a..a6df6ce0 100644
--- a/src/orchestrator.rs
+++ b/src/orchestrator.rs
@@ -17,27 +17,34 @@
 pub mod errors;
 pub use errors::Error;

+pub mod chat_completions_detection;
 pub mod streaming;
 pub mod unary;

 use std::{collections::HashMap, sync::Arc};

 use axum::http::header::HeaderMap;
+use opentelemetry::trace::TraceId;
 use tokio::{sync::RwLock, time::Instant};
 use tracing::{debug, info};
-use uuid::Uuid;

 use crate::{
     clients::{
-        self, detector::ContextType, ChunkerClient, DetectorClient, GenerationClient, NlpClient,
-        TgisClient, COMMON_ROUTER_KEY,
+        self,
+        chunker::ChunkerClient,
+        detector::{
+            text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient,
+            TextGenerationDetectorClient,
+        },
+        openai::{ChatCompletionsRequest, OpenAiClient},
+        ClientMap, GenerationClient, NlpClient, NlpHttpClient, TextContentsDetectorClient,
+        TgisClient,
     },
-    config::{GenerationProvider, OrchestratorConfig},
-    health::{HealthCheckCache, HealthProbe, HealthProbeResponse},
+    config::{DetectorType, GenerationProvider, OrchestratorConfig},
+    health::HealthCheckCache,
     models::{
-        ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams,
-        GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest,
-        GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
+        ChatDetectionHttpRequest, ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest,
+        DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig,
+        GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
     },
 };

@@ -48,16 +55,20 @@ const UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. \

 #[cfg_attr(test, derive(Default))]
 pub struct Context {
     config: OrchestratorConfig,
-    generation_client: GenerationClient,
-    chunker_client: ChunkerClient,
-    detector_client: DetectorClient,
+    clients: ClientMap,
+}
+
+impl Context {
+    pub fn new(config: OrchestratorConfig, clients: ClientMap) -> Self {
+        Self { config, clients }
+    }
 }

 /// Handles orchestrator tasks.
 #[cfg_attr(test, derive(Default))]
 pub struct Orchestrator {
     ctx: Arc<Context>,
-    client_health_cache: Arc<RwLock<HealthCheckCache>>,
+    client_health: Arc<RwLock<HealthCheckCache>>,
 }

 impl Orchestrator {
@@ -65,16 +76,11 @@ impl Orchestrator {
         config: OrchestratorConfig,
         start_up_health_check: bool,
     ) -> Result<Self, Error> {
-        let (generation_client, chunker_client, detector_client) = create_clients(&config).await;
-        let ctx = Arc::new(Context {
-            config,
-            generation_client,
-            chunker_client,
-            detector_client,
-        });
+        let clients = create_clients(&config).await;
+        let ctx = Arc::new(Context { config, clients });
         let orchestrator = Self {
             ctx,
-            client_health_cache: Arc::new(RwLock::new(HealthCheckCache::default())),
+            client_health: Arc::new(RwLock::new(HealthCheckCache::default())),
         };
         debug!("running start up checks");
         orchestrator.on_start_up(start_up_health_check).await?;
@@ -92,37 +98,34 @@ impl Orchestrator {
     pub async fn on_start_up(&self, health_check: bool) -> Result<(), Error> {
         info!("Performing start-up actions for orchestrator...");
         if health_check {
-            info!("Probing health status of configured clients...");
-            // Run probe, update cache
-            let res = self.clients_health(true).await.unwrap_or_else(|e| {
-                // Panic for unexpected behaviour as there are currently no errors propagated to here.
-                panic!("Unexpected error during client health probing: {}", e);
-            });
+            info!("Probing client health...");
+            let client_health = self.client_health(true).await;
             // Results of probe do not affect orchestrator start-up.
-            info!("Orchestrator client health probe results:\n{}", res);
+            info!("Client health:\n{client_health}");
         }
         Ok(())
     }

-    pub async fn clients_health(&self, probe: bool) -> Result<HealthProbeResponse, Error> {
-        let initialized = self.client_health_cache.read().await.is_initialized();
+    /// Returns client health state.
+    pub async fn client_health(&self, probe: bool) -> HealthCheckCache {
+        let initialized = !self.client_health.read().await.is_empty();
         if probe || !initialized {
             debug!("refreshing health cache");
             let now = Instant::now();
-            let detectors = self.ctx.detector_client.health().await?;
-            let chunkers = self.ctx.chunker_client.health().await?;
-            let generation = self.ctx.generation_client.health().await?;
-            let mut health_cache = self.client_health_cache.write().await;
-            health_cache.detectors = detectors;
-            health_cache.chunkers = chunkers;
-            health_cache.generation = generation;
+            let mut health = HealthCheckCache::with_capacity(self.ctx.clients.len());
+            // TODO: perform health checks concurrently?
+            for (key, client) in self.ctx.clients.iter() {
+                let result = client.health().await;
+                health.insert(key.into(), result);
+            }
+            let mut client_health = self.client_health.write().await;
+            *client_health = health;
             debug!(
                 "refreshing health cache completed in {:.2?}ms",
                 now.elapsed().as_millis()
             );
         }
-
-        Ok(HealthProbeResponse::from_cache(self.client_health_cache.clone()).await)
+        self.client_health.read().await.clone()
     }
 }
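Reviewer note: on the `TODO: perform health checks concurrently?` above, one way the probes could be made concurrent (illustrative only; `probe` stands in for `client.health()`):

```rust
use futures::future::join_all;

// Sketch for the TODO: run all per-client probes concurrently and collect
// the results, instead of awaiting them one by one in a loop.
#[tokio::main]
async fn main() {
    let keys = ["generation", "chat_generation", "hap"];
    let results: Vec<(String, String)> = join_all(
        keys.iter()
            .map(|key| async move { (key.to_string(), probe(key).await) }),
    )
    .await;
    println!("{results:?}");
}

async fn probe(_key: &str) -> String {
    "HEALTHY".to_string() // placeholder for a real health check
}
```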
@@ -162,50 +165,94 @@ fn get_chunker_ids(
         .collect::<Result<Vec<_>, Error>>()
 }

-async fn create_clients(
-    config: &OrchestratorConfig,
-) -> (GenerationClient, ChunkerClient, DetectorClient) {
-    // TODO: create better solution for routers
-    let generation_client = match &config.generation {
-        Some(generation) => match &generation.provider {
+async fn create_clients(config: &OrchestratorConfig) -> ClientMap {
+    let mut clients = ClientMap::new();
+
+    // Create generation client
+    if let Some(generation) = &config.generation {
+        match generation.provider {
             GenerationProvider::Tgis => {
-                let client = TgisClient::new(
-                    clients::DEFAULT_TGIS_PORT,
-                    &[(COMMON_ROUTER_KEY.to_string(), generation.service.clone())],
-                )
-                .await;
-                GenerationClient::tgis(client)
+                let tgis_client = TgisClient::new(&generation.service).await;
+                let generation_client = GenerationClient::tgis(tgis_client);
+                clients.insert("generation".to_string(), generation_client);
             }
             GenerationProvider::Nlp => {
-                let client = NlpClient::new(
-                    clients::DEFAULT_CAIKIT_NLP_PORT,
-                    &[(COMMON_ROUTER_KEY.to_string(), generation.service.clone())],
-                )
-                .await;
-                GenerationClient::nlp(client)
+                let nlp_client = NlpClient::new(&generation.service).await;
+                let generation_client = GenerationClient::nlp(nlp_client);
+                clients.insert("generation".to_string(), generation_client);
             }
-        },
-        None => GenerationClient::not_configured(),
-    };
-    // TODO: simplify all of this
-    let chunker_config = match &config.chunkers {
-        Some(chunkers) => chunkers
-            .iter()
-            .map(|(chunker_id, config)| (chunker_id.clone(), config.service.clone()))
-            .collect::<Vec<_>>(),
-        None => vec![],
-    };
-    let chunker_client = ChunkerClient::new(clients::DEFAULT_CHUNKER_PORT, &chunker_config).await;
-
-    let detector_config = config
-        .detectors
-        .iter()
-        .map(|(detector_id, config)| (detector_id.clone(), config.service.clone()))
-        .collect::<Vec<_>>();
-    let detector_client =
-        DetectorClient::new(clients::DEFAULT_DETECTOR_PORT, &detector_config).await;
-
-    (generation_client, chunker_client, detector_client)
+            GenerationProvider::NlpHttp => {
+                let nlp_http_client = NlpHttpClient::new(&generation.service).await;
+                let generation_client = GenerationClient::nlp_http(nlp_http_client);
+                clients.insert("generation".to_string(), generation_client);
+            }
+        }
+    }
+
+    // Create chat generation client
+    if let Some(chat_generation) = &config.chat_generation {
+        let openai_client = OpenAiClient::new(
+            &chat_generation.service,
+            chat_generation.health_service.as_ref(),
+        )
+        .await;
+        clients.insert("chat_generation".to_string(), openai_client);
+    }
+
+    // Create chunker clients
+    if let Some(chunkers) = &config.chunkers {
+        for (chunker_id, chunker) in chunkers {
+            let chunker_client = ChunkerClient::new(&chunker.service).await;
+            clients.insert(chunker_id.to_string(), chunker_client);
+        }
+    }
+
+    // Create detector clients
+    for (detector_id, detector) in &config.detectors {
+        match detector.r#type {
+            DetectorType::TextContents => {
+                clients.insert(
+                    detector_id.into(),
+                    TextContentsDetectorClient::new(
+                        &detector.service,
+                        detector.health_service.as_ref(),
+                    )
+                    .await,
+                );
+            }
+            DetectorType::TextGeneration => {
+                clients.insert(
+                    detector_id.into(),
+                    TextGenerationDetectorClient::new(
+                        &detector.service,
+                        detector.health_service.as_ref(),
+                    )
+                    .await,
+                );
+            }
+            DetectorType::TextChat => {
+                clients.insert(
+                    detector_id.into(),
+                    TextChatDetectorClient::new(
+                        &detector.service,
+                        detector.health_service.as_ref(),
+                    )
+                    .await,
+                );
+            }
+            DetectorType::TextContextDoc => {
+                clients.insert(
+                    detector_id.into(),
+                    TextContextDocDetectorClient::new(
+                        &detector.service,
+                        detector.health_service.as_ref(),
+                    )
+                    .await,
+                );
+            }
+        }
+    }
+    clients
 }

 #[derive(Debug, Clone)]
@@ -216,7 +263,7 @@ pub struct Chunk {

 #[derive(Debug)]
 pub struct ClassificationWithGenTask {
-    pub request_id: Uuid,
+    pub trace_id: TraceId,
     pub model_id: String,
     pub inputs: String,
     pub guardrails_config: GuardrailsConfig,
@@ -225,9 +272,9 @@ pub struct ClassificationWithGenTask {
 }

 impl ClassificationWithGenTask {
-    pub fn new(request_id: Uuid, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self {
+    pub fn new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self {
         Self {
-            request_id,
+            trace_id,
             model_id: request.model_id,
             inputs: request.inputs,
             guardrails_config: request.guardrail_config.unwrap_or_default(),
@@ -240,8 +287,8 @@ impl ClassificationWithGenTask {
 /// Task for the /api/v2/text/detection/content endpoint
 #[derive(Debug)]
 pub struct GenerationWithDetectionTask {
-    /// Request unique identifier
-    pub request_id: Uuid,
+    /// Unique identifier of request trace
+    pub trace_id: TraceId,

     /// Model ID of the LLM
     pub model_id: String,
@@ -261,12 +308,12 @@ pub struct GenerationWithDetectionTask {

 impl GenerationWithDetectionTask {
     pub fn new(
-        request_id: Uuid,
+        trace_id: TraceId,
         request: GenerationWithDetectionHttpRequest,
         headers: HeaderMap,
     ) -> Self {
         Self {
-            request_id,
+            trace_id,
             model_id: request.model_id,
             prompt: request.prompt,
             detectors: request.detectors,
@@ -279,8 +326,8 @@ impl GenerationWithDetectionTask {
 /// Task for the /api/v2/text/detection/content endpoint
 #[derive(Debug)]
 pub struct TextContentDetectionTask {
-    /// Request unique identifier
-    pub request_id: Uuid,
+    /// Unique identifier of request trace
+    pub trace_id: TraceId,

     /// Content to run detection on
     pub content: String,
@@ -294,12 +341,12 @@ pub struct TextContentDetectionTask {

 impl TextContentDetectionTask {
     pub fn new(
-        request_id: Uuid,
+        trace_id: TraceId,
         request: TextContentDetectionHttpRequest,
         headers: HeaderMap,
     ) -> Self {
         Self {
-            request_id,
+            trace_id,
             content: request.content,
             detectors: request.detectors,
             headers,
@@ -310,8 +357,8 @@ impl TextContentDetectionTask {
 /// Task for the /api/v1/text/task/detection/context endpoint
 #[derive(Debug)]
 pub struct ContextDocsDetectionTask {
-    /// Request unique identifier
-    pub request_id: Uuid,
+    /// Unique identifier of request trace
+    pub trace_id: TraceId,

     /// Content to run detection on
     pub content: String,
@@ -330,9 +377,9 @@ pub struct ContextDocsDetectionTask {
 }

 impl ContextDocsDetectionTask {
-    pub fn new(request_id: Uuid, request: ContextDocsHttpRequest, headers: HeaderMap) -> Self {
+    pub fn new(trace_id: TraceId, request: ContextDocsHttpRequest, headers: HeaderMap) -> Self {
         Self {
-            request_id,
+            trace_id,
             content: request.content,
             context_type: request.context_type,
             context: request.context,
@@ -342,11 +389,38 @@ impl ContextDocsDetectionTask {
     }
 }

+/// Task for the /api/v2/text/detection/chat endpoint
+#[derive(Debug)]
+pub struct ChatDetectionTask {
+    /// Unique identifier of request trace
+    pub trace_id: TraceId,
+
+    /// Detectors configuration
+    pub detectors: HashMap<String, DetectorParams>,
+
+    // Messages to run detection on
+    pub messages: Vec<clients::openai::Message>,
+
+    // Headermap
+    pub headers: HeaderMap,
+}
+
+impl ChatDetectionTask {
+    pub fn new(trace_id: TraceId, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self {
+        Self {
+            trace_id,
+            detectors: request.detectors,
+            messages: request.messages,
+            headers,
+        }
+    }
+}
+
 /// Task for the /api/v2/text/detection/generated endpoint
 #[derive(Debug)]
 pub struct DetectionOnGenerationTask {
-    /// Request unique identifier
-    pub request_id: Uuid,
+    /// Unique identifier of request trace
+    pub trace_id: TraceId,

     /// User prompt to be sent to the LLM
     pub prompt: String,
@@ -363,12 +437,12 @@ pub struct DetectionOnGenerationTask {

 impl DetectionOnGenerationTask {
     pub fn new(
-        request_id: Uuid,
+        trace_id: TraceId,
         request: DetectionOnGeneratedHttpRequest,
         headers: HeaderMap,
     ) -> Self {
         Self {
-            request_id,
+            trace_id,
             prompt: request.prompt,
             generated_text: request.generated_text,
             detectors: request.detectors,
@@ -380,7 +454,7 @@ impl DetectionOnGenerationTask {
 #[allow(dead_code)]
 #[derive(Debug)]
 pub struct StreamingClassificationWithGenTask {
-    pub request_id: Uuid,
+    pub trace_id: TraceId,
     pub model_id: String,
     pub inputs: String,
     pub guardrails_config: GuardrailsConfig,
@@ -389,9 +463,9 @@ pub struct StreamingClassificationWithGenTask {
 }

 impl StreamingClassificationWithGenTask {
-    pub fn new(request_id: Uuid, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self {
+    pub fn new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self {
         Self {
-            request_id,
+            trace_id,
             model_id: request.model_id,
             inputs: request.inputs,
             guardrails_config: request.guardrail_config.unwrap_or_default(),
@@ -401,6 +475,26 @@ impl StreamingClassificationWithGenTask {
     }
 }

+#[derive(Debug)]
+pub struct ChatCompletionsDetectionTask {
+    /// Unique identifier of request trace
+    pub trace_id: TraceId,
+    /// Chat completion request
+    pub request: ChatCompletionsRequest,
+    // Headermap
+    pub headers: HeaderMap,
+}
+
+impl ChatCompletionsDetectionTask {
+    pub fn new(trace_id: TraceId, request: ChatCompletionsRequest, headers: HeaderMap) -> Self {
+        Self {
+            trace_id,
+            request,
+            headers,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/orchestrator/chat_completions_detection.rs b/src/orchestrator/chat_completions_detection.rs
new file mode 100644
index 00000000..dd06fab0
--- /dev/null
+++ b/src/orchestrator/chat_completions_detection.rs
@@ -0,0 +1,20 @@
+use tracing::{info, instrument};
+
+use super::{ChatCompletionsDetectionTask, Error, Orchestrator};
+use crate::clients::openai::{ChatCompletionsResponse, OpenAiClient};
+
+impl Orchestrator {
+    #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))]
+    pub async fn handle_chat_completions_detection(
+        &self,
+        task: ChatCompletionsDetectionTask,
+    ) -> Result<ChatCompletionsResponse, Error> {
+        info!("handling chat completions detection task");
+        let client = self
+            .ctx
+            .clients
+            .get_as::<OpenAiClient>("chat_generation")
+            .expect("chat_generation client not found");
+        Ok(client.chat_completions(task.request, task.headers).await?)
+    }
+}
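Reviewer note: `ClientMap::get_as::<T>` is defined outside this diff. One plausible implementation (an assumption, not the crate's code) is a map of boxed `Any` clients with downcasting on access, which would explain the typed lookups used throughout:

```rust
use std::{any::Any, collections::HashMap};

// Assumed shape only: store heterogeneous clients behind Any, recover the
// concrete type at the call site with downcast_ref.
#[derive(Default)]
struct ClientMap(HashMap<String, Box<dyn Any + Send + Sync>>);

impl ClientMap {
    fn insert<C: Any + Send + Sync>(&mut self, key: String, client: C) {
        self.0.insert(key, Box::new(client));
    }

    fn get_as<C: Any>(&self, key: &str) -> Option<&C> {
        self.0.get(key).and_then(|c| c.downcast_ref::<C>())
    }
}

fn main() {
    let mut clients = ClientMap::default();
    clients.insert("chat_generation".to_string(), 42u32); // stand-in client
    assert_eq!(clients.get_as::<u32>("chat_generation"), Some(&42));
    assert!(clients.get_as::<String>("chat_generation").is_none()); // wrong type
}
```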
diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs
index 0e738d83..861ab354 100644
--- a/src/orchestrator/streaming.rs
+++ b/src/orchestrator/streaming.rs
@@ -22,14 +22,17 @@ use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration};
 use aggregator::Aggregator;
 use axum::http::HeaderMap;
 use futures::{future::try_join_all, Stream, StreamExt, TryStreamExt};
-
 use tokio::sync::{broadcast, mpsc};
 use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
 use tracing::{debug, error, info, instrument};

 use super::{get_chunker_ids, Context, Error, Orchestrator, StreamingClassificationWithGenTask};
 use crate::{
-    clients::detector::ContentAnalysisRequest,
+    clients::{
+        chunker::{tokenize_whole_doc_stream, ChunkerClient, DEFAULT_CHUNKER_ID},
+        detector::ContentAnalysisRequest,
+        GenerationClient, TextContentsDetectorClient,
+    },
     models::{
         ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsTextGenerationParameters,
         InputWarning, InputWarningReason, TextGenTokenClassificationResults,
@@ -39,8 +42,7 @@ use crate::{
         unary::{input_detection_task, tokenize},
         UNSUITABLE_INPUT_MESSAGE,
     },
-    pb::caikit::runtime::chunkers,
-    pb::caikit_data_model::nlp::ChunkerTokenizationStreamResult,
+    pb::{caikit::runtime::chunkers, caikit_data_model::nlp::ChunkerTokenizationStreamResult},
 };

 pub type Chunk = ChunkerTokenizationStreamResult;
 pub type Detections = Vec<TokenClassificationResult>;

 impl Orchestrator {
     /// Handles streaming tasks.
-    #[instrument(name = "stream_handler", skip_all)]
+    #[instrument(skip_all, fields(trace_id = task.trace_id.to_string(), model_id = task.model_id, headers = ?task.headers))]
     pub async fn handle_streaming_classification_with_gen(
         &self,
         task: StreamingClassificationWithGenTask,
     ) -> ReceiverStream<Result<ClassifiedGeneratedTextStreamResult, Error>> {
+        info!(config = ?task.guardrails_config, "starting task");
+
         let ctx = self.ctx.clone();
-        let request_id = task.request_id;
+        let trace_id = task.trace_id;
         let model_id = task.model_id;
         let params = task.text_gen_parameters;
         let input_text = task.inputs;
         let headers = task.headers;

-        info!(%request_id, config = ?task.guardrails_config, "starting task");
-
         // Create response channel
         #[allow(clippy::type_complexity)]
         let (response_tx, response_rx): (
@@ -86,7 +88,7 @@ impl Orchestrator {
             {
                 Ok(result) => result,
                 Err(error) => {
-                    error!(%request_id, %error, "task failed");
+                    error!(%trace_id, %error, "task failed");
                     let _ = response_tx.send(Err(error)).await;
                     return;
                 }
@@ -94,7 +96,7 @@ impl Orchestrator {
             }
             _ => None,
         };
-        debug!(?input_detections);
+        debug!(?input_detections); // TODO: metrics
         if let Some(mut input_detections) = input_detections {
             // Detected HAP/PII
             // Do tokenization to get input_token_count
@@ -104,7 +106,7 @@ impl Orchestrator {
             {
                 Ok(result) => result,
                 Err(error) => {
-                    error!(%request_id, %error, "task failed");
+                    error!(%trace_id, %error, "task failed");
                     let _ = response_tx.send(Err(error)).await;
                     return;
                 }
@@ -139,7 +141,7 @@ impl Orchestrator {
             {
                 Ok(generation_stream) => generation_stream,
                 Err(error) => {
-                    error!(%request_id, %error, "task failed");
+                    error!(%trace_id, %error, "task failed");
                     let _ = response_tx.send(Err(error)).await;
                     return;
                 }
@@ -169,7 +171,7 @@ impl Orchestrator {
             {
                 Ok(result_rx) => result_rx,
                 Err(error) => {
-                    error!(%request_id, %error, "task failed");
+                    error!(%trace_id, %error, "task failed");
                     let _ = error_tx.send(error.clone());
                     let _ = response_tx.send(Err(error)).await;
                     return;
@@ -181,19 +183,19 @@ impl Orchestrator {
                 loop {
                     tokio::select! {
                        Ok(error) = error_rx.recv() => {
-                            error!(%request_id, %error, "task failed");
-                            debug!(%request_id, "sending error to client and terminating");
+                            error!(%trace_id, %error, "task failed");
+                            debug!(%trace_id, "sending error to client and terminating");
                            let _ = response_tx.send(Err(error)).await;
                            return;
                        },
                        result = result_rx.recv() => {
                            match result {
                                Some(result) => {
-                                    debug!(%request_id, ?result, "sending result to client");
+                                    debug!(%trace_id, ?result, "sending result to client");
                                    let _ = response_tx.send(result).await;
                                },
                                None => {
-                                    info!(%request_id, "task completed: stream closed");
+                                    info!(%trace_id, "task completed: stream closed");
                                    break;
                                },
                            }
@@ -206,10 +208,10 @@ impl Orchestrator {
             // No output detectors, forward generation results to response channel
             tokio::spawn(async move {
                 while let Some(result) = generation_stream.next().await {
-                    debug!(%request_id, ?result, "sending result to client");
+                    debug!(%trace_id, ?result, "sending result to client");
                     let _ = response_tx.send(result).await;
                 }
-                debug!(%request_id, "task completed: stream closed");
+                debug!(%trace_id, "task completed: stream closed");
             });
         }
     }
@@ -230,10 +232,11 @@ async fn streaming_output_detection_task(
     error_tx: broadcast::Sender<Error>,
     headers: HeaderMap,
 ) -> Result<mpsc::Receiver<Result<ClassifiedGeneratedTextStreamResult, Error>>, Error> {
+    debug!(?detectors, "creating chunk broadcast streams");
+
     // Create generation broadcast stream
     let (generation_tx, generation_rx) = broadcast::channel(1024);

-    debug!("creating chunk broadcast streams");
     let chunker_ids = get_chunker_ids(ctx, detectors)?;
     // Create a map of chunker_id->chunk_broadcast_stream
     // This is to enable fan-out of chunk streams to potentially multiple detectors that use the same chunker.
@@ -265,16 +268,25 @@ async fn streaming_output_detection_task(
     debug!("spawning detection tasks");
     let mut detection_streams = Vec::with_capacity(detectors.len());
     for (detector_id, detector_params) in detectors.iter() {
+        // Create a mutable copy of the parameters, so that we can modify it based on processing
+        let mut detector_params = detector_params.clone();
         let detector_id = detector_id.to_string();
-        let chunker_id = ctx.config.get_chunker_id(&detector_id).unwrap();
+        let chunker_id = ctx
+            .config
+            .get_chunker_id(&detector_id)
+            .expect("chunker id is not found");

         // Get the detector config
         // TODO: Add error handling
-        let detector_config = ctx.config.detectors.get(&detector_id).unwrap();
+        let detector_config = ctx
+            .config
+            .detectors
+            .get(&detector_id)
+            .expect("detector config not found");

         // Get the default threshold to use if threshold is not provided by the user
         let default_threshold = detector_config.default_threshold;
-        let threshold = detector_params.threshold().unwrap_or(default_threshold);
+        let threshold = detector_params.pop_threshold().unwrap_or(default_threshold);

         // Create detection stream
         let (detector_tx, detector_rx) = mpsc::channel(1024);
@@ -287,6 +299,7 @@ async fn streaming_output_detection_task(
         tokio::spawn(detection_task(
             ctx.clone(),
             detector_id.clone(),
+            detector_params,
             threshold,
             detector_tx,
             chunk_rx,
@@ -320,6 +333,7 @@ async fn generation_broadcast_task(
     generation_tx: broadcast::Sender<ClassifiedGeneratedTextStreamResult>,
     error_tx: broadcast::Sender<Error>,
 ) {
+    debug!("forwarding response stream");
     let mut error_rx = error_tx.subscribe();
     loop {
         tokio::select! {
@@ -349,16 +363,19 @@ async fn generation_broadcast_task(

 /// Wraps a unary detector service to make it streaming.
 /// Consumes chunk broadcast stream, sends unary requests to a detector service,
 /// and sends chunk + responses to detection stream.
-#[instrument(skip_all)]
+#[allow(clippy::too_many_arguments)]
+#[instrument(skip_all, fields(detector_id))]
 async fn detection_task(
     ctx: Arc<Context>,
     detector_id: String,
+    detector_params: DetectorParams,
     threshold: f64,
     detector_tx: mpsc::Sender<(Chunk, Detections)>,
     mut chunk_rx: broadcast::Receiver<Chunk>,
     error_tx: broadcast::Sender<Error>,
     headers: HeaderMap,
 ) {
+    debug!(threshold, "starting task");
     let mut error_rx = error_tx.subscribe();

     loop {
@@ -378,12 +395,14 @@ async fn detection_task(
                     debug!("empty chunk, skipping detector request.");
                     break;
                 } else {
-                    let request = ContentAnalysisRequest::new(contents.clone());
+                    let request = ContentAnalysisRequest::new(contents.clone(), detector_params.clone());
                     let headers = headers.clone();
                     debug!(%detector_id, ?request, "sending detector request");
-                    match ctx
-                        .detector_client
-                        .text_contents(&detector_id, request, headers)
+                    let client = ctx
+                        .clients
+                        .get_as::<TextContentsDetectorClient>(&detector_id)
+                        .unwrap_or_else(|| panic!("text contents detector client not found for {}", detector_id));
+                    match client.text_contents(&detector_id, request, headers)
                         .await
                         .map_err(|error| Error::DetectorRequestFailed { id: detector_id.clone(), error })
                     {
                         Ok(response) => {
@@ -424,7 +443,7 @@ async fn detection_task(

 /// Opens bi-directional stream to a chunker service
 /// with generation stream input and returns chunk broadcast stream.
-#[instrument(skip_all)]
+#[instrument(skip_all, fields(chunker_id))]
 async fn chunk_broadcast_task(
     ctx: Arc<Context>,
     chunker_id: String,
@@ -432,7 +451,7 @@ async fn chunk_broadcast_task(
     error_tx: broadcast::Sender<Error>,
 ) -> Result<broadcast::Sender<Chunk>, Error> {
     // Consume generation stream and convert to chunker input stream
-    debug!(%chunker_id, "creating chunker input stream");
+    debug!("creating chunker input stream");
     // NOTE: Text gen providers can return more than 1 token in single stream object. This can create
     // edge cases where the enumeration generated below may not line up with token / response boundaries.
     // So the more accurate way here might be to use `Tokens` object from response, but since that is an
@@ -450,12 +469,27 @@ async fn chunk_broadcast_task(
             }
         })
         .boxed();
-    debug!(%chunker_id, "creating chunker output stream");
+    debug!("creating chunker output stream");
     let id = chunker_id.clone(); // workaround for StreamExt::map_err
-    let mut output_stream = ctx
-        .chunker_client
-        .bidi_streaming_tokenization_task_predict(&chunker_id, input_stream)
-        .await
+
+    let response_stream = if chunker_id == DEFAULT_CHUNKER_ID {
+        info!("Using default whole doc chunker");
+        let (response_tx, response_rx) = mpsc::channel(1);
+        // Spawn task to collect input stream
+        tokio::spawn(async move {
+            // NOTE: this will not resolve until the input stream is closed
+            let response = tokenize_whole_doc_stream(input_stream).await;
+            let _ = response_tx.send(response).await;
+        });
+        Ok(ReceiverStream::new(response_rx).boxed())
+    } else {
+        let client = ctx.clients.get_as::<ChunkerClient>(&chunker_id).unwrap();
+        client
+            .bidi_streaming_tokenization_task_predict(&chunker_id, input_stream)
+            .await
+    };
+
+    let mut output_stream = response_stream
         .map_err(|error| Error::ChunkerRequestFailed {
             id: chunker_id.clone(),
             error,
@@ -466,7 +500,7 @@ async fn chunk_broadcast_task(
         }); // maps stream errors

     // Spawn task to consume output stream forward to broadcast channel
-    debug!(%chunker_id, "spawning chunker broadcast task");
+    debug!("spawning chunker broadcast task");
     let (chunk_tx, _) = broadcast::channel(1024);
     tokio::spawn({
         let mut error_rx = error_tx.subscribe();
@@ -478,17 +512,17 @@ async fn chunk_broadcast_task(
                     result = output_stream.next() => {
                         match result {
                             Some(Ok(chunk)) => {
-                                debug!(%chunker_id, ?chunk, "received chunk");
+                                debug!(?chunk, "received chunk");
                                 let _ = chunk_tx.send(chunk);
                             },
                             Some(Err(error)) => {
-                                error!(%chunker_id, %error, "chunker error, cancelling task");
+                                error!(%error, "chunker error, cancelling task");
                                 let _ = error_tx.send(error);
                                 tokio::time::sleep(Duration::from_millis(5)).await;
                                 break;
                             },
                             None => {
-                                debug!(%chunker_id, "stream closed");
+                                debug!("stream closed");
                                 break
                             },
                         }
@@ -501,6 +535,8 @@ async fn chunk_broadcast_task(
 }
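Reviewer note: the chunk streams above use `tokio::sync::broadcast` precisely so that several detector tasks can consume the same chunker output. The fan-out behaviour in isolation:

```rust
use tokio::sync::broadcast;

// Minimal illustration of the fan-out pattern used for chunk streams:
// one broadcast channel lets multiple subscribers see every message.
#[tokio::main]
async fn main() {
    let (tx, _) = broadcast::channel::<String>(16);
    let mut rx1 = tx.subscribe();
    let mut rx2 = tx.subscribe();

    tx.send("chunk-1".to_string()).unwrap();

    // Both "detectors" receive the same chunk.
    assert_eq!(rx1.recv().await.unwrap(), "chunk-1");
    assert_eq!(rx2.recv().await.unwrap(), "chunk-1");
}
```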
 /// Sends generate stream request to a generation service.
+#[allow(clippy::type_complexity)]
+#[instrument(skip_all, fields(model_id))]
 async fn generate_stream(
     ctx: &Arc<Context>,
     model_id: String,
@@ -511,8 +547,12 @@ async fn generate_stream(
 ) -> Result<
     Pin<Box<dyn Stream<Item = Result<ClassifiedGeneratedTextStreamResult, Error>> + Send>>,
     Error,
 > {
-    Ok(ctx
-        .generation_client
+    debug!(?params, "sending generate stream request");
+    let client = ctx
+        .clients
+        .get_as::<GenerationClient>("generation")
+        .unwrap();
+    Ok(client
         .generate_stream(model_id.clone(), text, params, headers)
         .await
         .map_err(|error| Error::GenerateRequestFailed {
diff --git a/src/orchestrator/streaming/aggregator.rs b/src/orchestrator/streaming/aggregator.rs
index 97c83d64..c1f6fc72 100644
--- a/src/orchestrator/streaming/aggregator.rs
+++ b/src/orchestrator/streaming/aggregator.rs
@@ -142,7 +142,9 @@ impl ResultActor {
             result.token_classification_results.output = Some(detections);
             if input_start_index == 0 {
                 // Get input_token_count and seed from first generation message
-                let first = generations.first().unwrap();
+                let first = generations
+                    .first()
+                    .expect("first element in classified generated text stream result not found");
                 result.input_token_count = first.input_token_count;
                 result.seed = first.seed;
                 // Get input_tokens from second generation message (if specified)
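Reviewer note: the unary handlers below (including the new chat handler) fan out their detector calls with `try_join_all`, so the calls run concurrently and the first error aborts the whole batch. A standalone sketch of the pattern (hypothetical detector ids):

```rust
use futures::future::try_join_all;

// Fan-out-and-fail-fast: every detector call runs concurrently; the first
// Err short-circuits and is returned from try_join_all.
#[tokio::main]
async fn main() -> Result<(), String> {
    let detectors = vec!["hap", "pii"];
    let results = try_join_all(detectors.into_iter().map(|id| async move {
        // A real handler would call the detector client here.
        Ok::<_, String>(format!("{id}: no detections"))
    }))
    .await?;
    println!("{results:?}");
    Ok(())
}
```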
- #[instrument(name = "unary_handler", skip_all)] + #[instrument(skip_all, fields(trace_id = ?task.trace_id, model_id = task.model_id, headers = ?task.headers))] pub async fn handle_classification_with_gen( &self, task: ClassificationWithGenTask, ) -> Result { let ctx = self.ctx.clone(); - let request_id = task.request_id; + let trace_id = task.trace_id; let headers = task.headers; - info!(%request_id, config = ?task.guardrails_config, "starting task"); + info!(config = ?task.guardrails_config, "handling classification with generation task"); let task_handle = tokio::spawn(async move { let input_text = task.inputs.clone(); let masks = task.guardrails_config.input_masks(); @@ -136,32 +144,31 @@ impl Orchestrator { match task_handle.await { // Task completed successfully Ok(Ok(result)) => { - debug!(%request_id, ?result, "sending result to client"); - info!(%request_id, "task completed"); + debug!(%trace_id, ?result, "sending result to client"); + info!(%trace_id, "task completed"); Ok(result) } // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(%request_id, %error, "task failed"); + error!(%trace_id, %error, "task failed"); Err(error) } } } /// Handles the given generation task, followed by detections. + #[instrument(skip_all, fields(trace_id = ?task.trace_id, model_id = task.model_id, headers = ?task.headers))] pub async fn handle_generation_with_detection( &self, task: GenerationWithDetectionTask, ) -> Result { info!( - request_id = ?task.request_id, - model_id = %task.model_id, detectors = ?task.detectors, "handling generation with detection task" ); @@ -221,27 +228,25 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "generation with detection unary task failed"); + error!(trace_id = ?task.trace_id, %error, "generation with detection unary task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "generation with detection unary task failed"); + error!(trace_id = ?task.trace_id, %error, "generation with detection unary task failed"); Err(error) } } } /// Handles detection on textual content + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] pub async fn handle_text_content_detection( &self, task: TextContentDetectionTask, ) -> Result { - info!( - request_id = ?task.request_id, - "handling text content detection task" - ); + info!("handling text content detection task"); let ctx = self.ctx.clone(); let headers = task.headers; @@ -265,13 +270,19 @@ impl Orchestrator { let ctx = ctx.clone(); let detector_id = detector_id.clone(); let detector_params = detector_params.clone(); - let detector_config = ctx.config.detectors.get(&detector_id).unwrap(); + let detector_config = + ctx.config.detectors.get(&detector_id).unwrap_or_else(|| { + panic!("detector config not found for {}", detector_id) + }); let chunker_id = detector_config.chunker_id.as_str(); let default_threshold = detector_config.default_threshold; - let chunk = chunks.get(chunker_id).unwrap().clone(); + let chunk = chunks + .get(chunker_id) + .unwrap_or_else(|| panic!("chunk not found for {}", chunker_id)) + .clone(); let headers = headers.clone(); @@ 
-303,25 +314,25 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "text content detection task failed"); + error!(trace_id = ?task.trace_id, %error, "text content detection task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "text content detection task failed"); + error!(trace_id = ?task.trace_id, %error, "text content detection task failed"); Err(error) } } } /// Handles context-related detections on textual content + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] pub async fn handle_context_documents_detection( &self, task: ContextDocsDetectionTask, ) -> Result { info!( - request_id = ?task.request_id, detectors = ?task.detectors, "handling context documents detection task" ); @@ -368,25 +379,25 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "context documents detection task failed"); + error!(trace_id = ?task.trace_id, %error, "context documents detection task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "context documents detection task failed"); + error!(trace_id = ?task.trace_id, %error, "context documents detection task failed"); Err(error) } } } /// Handles detections on generated text (without performing generation) + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] pub async fn handle_generated_text_detection( &self, task: DetectionOnGenerationTask, ) -> Result { info!( - request_id = ?task.request_id, detectors = ?task.detectors, "handling detection on generated content task" ); @@ -431,13 +442,68 @@ impl Orchestrator { Ok(Ok(result)) => Ok(result), // Task failed, return error propagated from child task that failed Ok(Err(error)) => { - error!(request_id = ?task.request_id, %error, "detection on generated content task failed"); + error!(trace_id = ?task.trace_id, %error, "detection on generated content task failed"); Err(error) } // Task cancelled or panicked Err(error) => { let error = error.into(); - error!(request_id = ?task.request_id, %error, "detection on generated content task failed"); + error!(trace_id = ?task.trace_id, %error, "detection on generated content task failed"); + Err(error) + } + } + } + + /// Handles detections on chat messages (without performing generation) + #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] + pub async fn handle_chat_detection( + &self, + task: ChatDetectionTask, + ) -> Result { + info!( + detectors = ?task.detectors, + "handling detection on chat content task" + ); + let ctx = self.ctx.clone(); + let headers = task.headers; + + let task_handle = tokio::spawn(async move { + // call detection + let detections = try_join_all( + task.detectors + .iter() + .map(|(detector_id, detector_params)| { + let ctx = ctx.clone(); + let detector_id = detector_id.clone(); + let detector_params = detector_params.clone(); + let messages = task.messages.clone(); + let headers = headers.clone(); + async { + detect_for_chat(ctx, detector_id, detector_params, messages, headers) + .await + } + }) + .collect::>(), + ) + .await? 
+ .into_iter() + .flatten() + .collect::>(); + + Ok(ChatDetectionResult { detections }) + }); + match task_handle.await { + // Task completed successfully + Ok(Ok(result)) => Ok(result), + // Task failed, return error propagated from child task that failed + Ok(Err(error)) => { + error!(%error, "detection task on chat failed"); + Err(error) + } + // Task cancelled or panicked + Err(error) => { + let error = error.into(); + error!(%error, "detection task on chat failed"); Err(error) } } @@ -453,6 +519,7 @@ pub async fn input_detection_task( masks: Option<&[(usize, usize)]>, headers: HeaderMap, ) -> Result>, Error> { + debug!(?detectors, "starting input detection"); let text_with_offsets = apply_masks(input_text, masks); let chunker_ids = get_chunker_ids(ctx, detectors)?; let chunks = chunk_task(ctx, chunker_ids, text_with_offsets).await?; @@ -468,6 +535,7 @@ async fn output_detection_task( generated_text: String, headers: HeaderMap, ) -> Result>, Error> { + debug!(detectors = ?detectors.keys(), "starting output detection"); let text_with_offsets = apply_masks(generated_text, None); let chunker_ids = get_chunker_ids(ctx, detectors)?; let chunks = chunk_task(ctx, chunker_ids, text_with_offsets).await?; @@ -483,6 +551,7 @@ async fn detection_task( chunks: HashMap>, headers: HeaderMap, ) -> Result, Error> { + debug!(detectors = ?detectors.keys(), "handling detection tasks"); // Spawn tasks for each detector let tasks = detectors .iter() @@ -532,6 +601,7 @@ async fn chunk_task( chunker_ids: Vec, text_with_offsets: Vec<(usize, String)>, ) -> Result>, Error> { + debug!(?chunker_ids, "handling chunk task"); // Spawn tasks for each chunker let tasks = chunker_ids .into_iter() @@ -549,36 +619,40 @@ async fn chunk_task( } /// Sends a request to a detector service and applies threshold. -#[instrument(skip_all)] +#[instrument(skip_all, fields(detector_id))] pub async fn detect( ctx: Arc, detector_id: String, default_threshold: f64, - detector_params: DetectorParams, + mut detector_params: DetectorParams, chunks: Vec, headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or(default_threshold); + let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); let contents: Vec<_> = chunks.iter().map(|chunk| chunk.text.clone()).collect(); let response = if contents.is_empty() { // skip detector call as contents is empty Vec::default() } else { - let request = ContentAnalysisRequest::new(contents); - debug!(%detector_id, ?request, "sending detector request"); - ctx.detector_client + let request = ContentAnalysisRequest::new(contents, detector_params); + debug!(?request, "sending detector request"); + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + client .text_contents(&detector_id, request, headers) .await .map_err(|error| { - debug!(%detector_id, ?error, "error received from detector"); + debug!(?error, "error received from detector"); Error::DetectorRequestFailed { id: detector_id.clone(), error, } })? }; - debug!(%detector_id, ?response, "received detector response"); + debug!(?response, "received detector response"); if chunks.len() != response.len() { return Err(Error::Other(format!( "Detector {detector_id} did not return expected number of responses" @@ -604,29 +678,33 @@ pub async fn detect( /// Sends a request to a detector service and applies threshold. 
/// TODO: Cleanup by removing duplicate code and merging it with above `detect` function -#[instrument(skip_all)] +#[instrument(skip_all, fields(detector_id))] pub async fn detect_content( ctx: Arc, detector_id: String, default_threshold: f64, - detector_params: DetectorParams, + mut detector_params: DetectorParams, chunks: Vec, headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or(default_threshold); + let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); let contents: Vec<_> = chunks.iter().map(|chunk| chunk.text.clone()).collect(); let response = if contents.is_empty() { // skip detector call as contents is empty Vec::default() } else { - let request = ContentAnalysisRequest::new(contents); - debug!(%detector_id, ?request, "sending detector request"); - ctx.detector_client + let request = ContentAnalysisRequest::new(contents, detector_params); + debug!(?request, threshold, "sending detector request"); + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + client .text_contents(&detector_id, request, headers) .await .map_err(|error| { - debug!(%detector_id, ?error, "error received from detector"); + debug!(?error, "error received from detector"); Error::DetectorRequestFailed { id: detector_id.clone(), error, @@ -657,17 +735,18 @@ pub async fn detect_content( } /// Calls a detector that implements the /api/v1/text/generation endpoint +#[instrument(skip_all, fields(detector_id))] pub async fn detect_for_generation( ctx: Arc, detector_id: String, - detector_params: DetectorParams, + mut detector_params: DetectorParams, prompt: String, generated_text: String, headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or( - detector_params.threshold().unwrap_or( + let threshold = detector_params.pop_threshold().unwrap_or( + detector_params.pop_threshold().unwrap_or( ctx.config .detectors .get(&detector_id) @@ -675,10 +754,19 @@ pub async fn detect_for_generation( .default_threshold, ), ); - let request = GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()); - debug!(%detector_id, ?request, "sending generation detector request"); - let response = ctx - .detector_client + let request = + GenerationDetectionRequest::new(prompt.clone(), generated_text.clone(), detector_params); + debug!(threshold, ?request, "sending generation detector request"); + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap_or_else(|| { + panic!( + "text generation detector client not found for {}", + detector_id + ) + }); + let response = client .text_generation(&detector_id, request, headers) .await .map(|results| { @@ -691,23 +779,65 @@ pub async fn detect_for_generation( id: detector_id.clone(), error, })?; - debug!(%detector_id, ?response, "received generation detector response"); + debug!(?response, "received generation detector response"); + Ok::, Error>(response) +} + +/// Calls a detector that implements the /api/v1/text/chat endpoint +pub async fn detect_for_chat( + ctx: Arc, + detector_id: String, + mut detector_params: DetectorParams, + messages: Vec, + headers: HeaderMap, +) -> Result, Error> { + let detector_id = detector_id.clone(); + let threshold = detector_params.pop_threshold().unwrap_or( + detector_params.pop_threshold().unwrap_or( + ctx.config + .detectors + .get(&detector_id) + .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))? 
+ .default_threshold, + ), + ); + let request = ChatDetectionRequest::new(messages.clone(), detector_params); + debug!(%detector_id, ?request, "sending chat detector request"); + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + let response = client + .text_chat(&detector_id, request, headers) + .await + .map(|results| { + results + .into_iter() + .filter(|detection| detection.score > threshold) + .collect() + }) + .map_err(|error| Error::DetectorRequestFailed { + id: detector_id.clone(), + error, + })?; + debug!(%detector_id, ?response, "received chat detector response"); Ok::, Error>(response) } /// Calls a detector that implements the /api/v1/text/doc endpoint +#[instrument(skip_all, fields(detector_id))] pub async fn detect_for_context( ctx: Arc, detector_id: String, - detector_params: DetectorParams, + mut detector_params: DetectorParams, content: String, context_type: ContextType, context: Vec, headers: HeaderMap, ) -> Result, Error> { let detector_id = detector_id.clone(); - let threshold = detector_params.threshold().unwrap_or( - detector_params.threshold().unwrap_or( + let threshold = detector_params.pop_threshold().unwrap_or( + detector_params.pop_threshold().unwrap_or( ctx.config .detectors .get(&detector_id) @@ -715,10 +845,24 @@ pub async fn detect_for_context( .default_threshold, ), ); - let request = ContextDocsDetectionRequest::new(content, context_type, context, detector_params); - debug!(%detector_id, ?request, "sending context detector request"); - let response = ctx - .detector_client + let request = + ContextDocsDetectionRequest::new(content, context_type, context, detector_params.clone()); + debug!( + ?request, + threshold, + ?detector_params, + "sending context detector request" + ); + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap_or_else(|| { + panic!( + "text context doc detector client not found for {}", + detector_id + ) + }); + let response = client .text_context_doc(&detector_id, request, headers) .await .map(|results| { @@ -736,7 +880,7 @@ pub async fn detect_for_context( } /// Sends request to chunker service. -#[instrument(skip_all)] +#[instrument(skip_all, fields(chunker_id))] pub async fn chunk( ctx: &Arc, chunker_id: String, @@ -744,16 +888,21 @@ pub async fn chunk( text: String, ) -> Result, Error> { let request = chunkers::ChunkerTokenizationTaskRequest { text }; - debug!(%chunker_id, ?request, "sending chunker request"); - let response = ctx - .chunker_client - .tokenization_task_predict(&chunker_id, request) - .await - .map_err(|error| Error::ChunkerRequestFailed { - id: chunker_id.clone(), - error, - })?; - debug!(%chunker_id, ?response, "received chunker response"); + debug!(?request, offset, "sending chunk request"); + let response = if chunker_id == DEFAULT_CHUNKER_ID { + tokenize_whole_doc(request) + } else { + let client = ctx.clients.get_as::(&chunker_id).unwrap(); + client + .tokenization_task_predict(&chunker_id, request) + .await + .map_err(|error| Error::ChunkerRequestFailed { + id: chunker_id.clone(), + error, + })? + }; + + debug!(?response, "received chunker response"); Ok(response .results .into_iter() @@ -765,11 +914,13 @@ pub async fn chunk( } /// Sends parallel requests to a chunker service. 
+#[instrument(skip_all, fields(chunker_id))] pub async fn chunk_parallel( ctx: &Arc, chunker_id: String, text_with_offsets: Vec<(usize, String)>, ) -> Result<(String, Vec), Error> { + debug!("sending parallel chunk requests"); let chunks = stream::iter(text_with_offsets) .map(|(offset, text)| { let ctx = ctx.clone(); @@ -791,13 +942,19 @@ pub async fn chunk_parallel( } /// Sends tokenize request to a generation service. +#[instrument(skip_all, fields(model_id))] pub async fn tokenize( ctx: &Arc, model_id: String, text: String, headers: HeaderMap, ) -> Result<(u32, Vec), Error> { - ctx.generation_client + debug!("sending tokenize request"); + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + client .tokenize(model_id.clone(), text, headers) .await .map_err(|error| Error::TokenizeRequestFailed { @@ -807,6 +964,7 @@ pub async fn tokenize( } /// Sends generate request to a generation service. +#[instrument(skip_all, fields(model_id))] async fn generate( ctx: &Arc, model_id: String, @@ -814,7 +972,12 @@ async fn generate( params: Option, headers: HeaderMap, ) -> Result { - ctx.generation_client + debug!("sending generate request"); + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + client .generate(model_id.clone(), text, params, headers) .await .map_err(|error| Error::GenerateRequestFailed { @@ -832,37 +995,20 @@ mod tests { clients::{ self, detector::{ContentAnalysisResponse, GenerationDetectionRequest}, - ChunkerClient, DetectorClient, GenerationClient, TgisClient, + ClientMap, GenerationClient, TgisClient, }, config::{DetectorConfig, OrchestratorConfig}, - models::{DetectionResult, EvidenceObj, FinishReason}, + models::{DetectionResult, EvidenceObj, FinishReason, THRESHOLD_PARAM}, pb::fmaas::{ BatchedGenerationRequest, BatchedGenerationResponse, GenerationRequest, GenerationResponse, StopReason, }, }; - async fn get_test_context( - gen_client: GenerationClient, - chunker_client: Option, - detector_client: Option, - ) -> Context { - let chunker_client = chunker_client.unwrap_or_default(); - let detector_client = detector_client.unwrap_or_default(); - - Context { - generation_client: gen_client, - chunker_client, - detector_client, - config: OrchestratorConfig::default(), - } - } - // Test for TGIS generation with default parameter #[tokio::test] async fn test_tgis_generate_with_default_params() { - // Initialize a mock object from `TgisClient` - let mut mock_client = TgisClient::faux(); + let mut tgis_client = TgisClient::faux(); let sample_text = String::from("sample text"); let text_gen_model_id = String::from("test-llm-id-1"); @@ -899,13 +1045,15 @@ mod tests { }; // Construct a behavior for the mock object - faux::when!(mock_client.generate(expected_generate_req_args, HeaderMap::new())) + faux::when!(tgis_client.generate(expected_generate_req_args, HeaderMap::new())) .once() // TODO: Add with_args .then_return(Ok(client_generation_response)); - let mock_generation_client = GenerationClient::tgis(mock_client.clone()); + let generation_client = GenerationClient::tgis(tgis_client.clone()); - let ctx = Arc::new(get_test_context(mock_generation_client, None, None).await); + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); // Test request formulation and response processing is as expected assert_eq!( @@ -925,8 +1073,8 @@ mod tests { /// 2. detections below the threshold are not returned to the client. 
#[tokio::test] async fn test_handle_detection_task() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextContentsDetectorClient::faux(); let detector_id = "mocked_hap_detector"; let threshold = 0.5; @@ -934,7 +1082,7 @@ mod tests { let first_sentence = "I don't like potatoes.".to_string(); let second_sentence = "I hate aliens.".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let chunks = vec![ Chunk { offset: 0, @@ -957,9 +1105,12 @@ mod tests { token_count: None, }]; - faux::when!(mock_detector_client.text_contents( + faux::when!(detector_client.text_contents( detector_id, - ContentAnalysisRequest::new(vec![first_sentence.clone(), second_sentence.clone()]), + ContentAnalysisRequest::new( + vec![first_sentence.clone(), second_sentence.clone()], + DetectorParams::new() + ), HeaderMap::new(), )) .once() @@ -984,12 +1135,14 @@ mod tests { }], ])); - let ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); assert_eq!( detect( - ctx.into(), + ctx, detector_id.to_string(), threshold, detector_params, @@ -1005,14 +1158,14 @@ mod tests { /// This test checks if calls to detectors returning 503 are being propagated in the orchestrator response. #[tokio::test] async fn test_detect_when_detector_returns_503() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextContentsDetectorClient::faux(); let detector_id = "mocked_503_detector"; let sentence = "This call will return a 503.".to_string(); let threshold = 0.5; let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let chunks = vec![Chunk { offset: 0, text: sentence.clone(), @@ -1027,9 +1180,9 @@ mod tests { }, }; - faux::when!(mock_detector_client.text_contents( + faux::when!(detector_client.text_contents( detector_id, - ContentAnalysisRequest::new(vec![sentence.clone()]), + ContentAnalysisRequest::new(vec![sentence.clone()], DetectorParams::new()), HeaderMap::new(), )) .once() @@ -1038,12 +1191,14 @@ mod tests { message: "Service Unavailable".to_string(), })); - let ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); assert_eq!( detect( - ctx.into(), + ctx, detector_id.to_string(), threshold, detector_params, @@ -1055,35 +1210,39 @@ mod tests { expected_response ); } + #[tokio::test] async fn test_handle_detection_task_with_whitespace() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let 
generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextContentsDetectorClient::faux(); let detector_id = "mocked_hap_detector"; let threshold = 0.5; let first_sentence = "".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let chunks = vec![Chunk { offset: 0, text: first_sentence.clone(), }]; - faux::when!(mock_detector_client.text_contents( + faux::when!(detector_client.text_contents( detector_id, - ContentAnalysisRequest::new(vec![first_sentence.clone()]), + ContentAnalysisRequest::new(vec![first_sentence.clone()], DetectorParams::new()), HeaderMap::new(), )) .once() .then_return(Ok(vec![vec![]])); - let ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); + let expected_response_whitespace = vec![]; assert_eq!( detect( - ctx.into(), + ctx, detector_id.to_string(), threshold, detector_params, @@ -1095,18 +1254,18 @@ mod tests { expected_response_whitespace ); } - /// This test checks if calls to detectors for the /generation-detection endpoint are being handled appropriately. + #[tokio::test] async fn test_detect_for_generation() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextGenerationDetectorClient::faux(); let detector_id = "mocked_answer_relevance_detector"; let threshold = 0.5; let prompt = "What is the capital of Brazil?".to_string(); let generated_text = "The capital of Brazil is Brasilia.".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let expected_response: Vec = vec![DetectionResult { detection_type: "relevance".to_string(), @@ -1123,9 +1282,13 @@ mod tests { ), }]; - faux::when!(mock_detector_client.text_generation( + faux::when!(detector_client.text_generation( detector_id, - GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()), + GenerationDetectionRequest::new( + prompt.clone(), + generated_text.clone(), + DetectorParams::new() + ), HeaderMap::new(), )) .once() @@ -1144,9 +1307,10 @@ mod tests { ), }])); - let mut ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; - + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let mut ctx = Context::new(OrchestratorConfig::default(), clients); // add detector ctx.config.detectors.insert( detector_id.to_string(), @@ -1154,10 +1318,11 @@ mod tests { ..Default::default() }, ); + let ctx = Arc::new(ctx); assert_eq!( detect_for_generation( - ctx.into(), + ctx, detector_id.to_string(), detector_params, prompt, @@ -1170,11 +1335,10 @@ mod tests { ); } - /// This test checks if calls to detectors for the /generation-detection endpoint only return detections above the threshold. 
#[tokio::test] async fn test_detect_for_generation_below_threshold() { - let mock_generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut mock_detector_client = DetectorClient::faux(); + let generation_client = GenerationClient::tgis(TgisClient::faux()); + let mut detector_client = TextGenerationDetectorClient::faux(); let detector_id = "mocked_answer_relevance_detector"; let threshold = 0.5; @@ -1182,13 +1346,17 @@ mod tests { let generated_text = "The most beautiful places can be found in Rio de Janeiro.".to_string(); let mut detector_params = DetectorParams::new(); - detector_params.insert("threshold".into(), threshold.into()); + detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); let expected_response: Vec = vec![]; - faux::when!(mock_detector_client.text_generation( + faux::when!(detector_client.text_generation( detector_id, - GenerationDetectionRequest::new(prompt.clone(), generated_text.clone()), + GenerationDetectionRequest::new( + prompt.clone(), + generated_text.clone(), + DetectorParams::new() + ), HeaderMap::new(), )) .once() @@ -1199,20 +1367,22 @@ mod tests { evidence: None, }])); - let mut ctx: Context = - get_test_context(mock_generation_client, None, Some(mock_detector_client)).await; - - // add mocked detector + let mut clients = ClientMap::new(); + clients.insert("generation".into(), generation_client); + clients.insert(detector_id.into(), detector_client); + let mut ctx = Context::new(OrchestratorConfig::default(), clients); + // add detector ctx.config.detectors.insert( detector_id.to_string(), DetectorConfig { ..Default::default() }, ); + let ctx = Arc::new(ctx); assert_eq!( detect_for_generation( - ctx.into(), + ctx, detector_id.to_string(), detector_params, prompt, diff --git a/src/server.rs b/src/server.rs index f88f861d..e7bd12d6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -40,22 +40,26 @@ use axum_extra::extract::WithRejection; use futures::{stream, Stream, StreamExt}; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; +use opentelemetry::trace::TraceContextExt; use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}; use tokio::{net::TcpListener, signal}; use tokio_rustls::TlsAcceptor; +use tokio_stream::wrappers::ReceiverStream; +use tower_http::trace::TraceLayer; use tower_service::Service; -use tracing::{debug, error, info, warn}; -use uuid::Uuid; +use tracing::{debug, error, info, instrument, warn, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use webpki::types::{CertificateDer, PrivateKeyDer}; use crate::{ - health::HealthCheckProbeParams, - models, + clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse}, + models::{self, InfoParams, InfoResponse}, orchestrator::{ - self, ClassificationWithGenTask, ContextDocsDetectionTask, DetectionOnGenerationTask, - GenerationWithDetectionTask, Orchestrator, StreamingClassificationWithGenTask, - TextContentDetectionTask, + self, ChatCompletionsDetectionTask, ChatDetectionTask, ClassificationWithGenTask, + ContextDocsDetectionTask, DetectionOnGenerationTask, GenerationWithDetectionTask, + Orchestrator, StreamingClassificationWithGenTask, TextContentDetectionTask, }, + tracing_utils, }; const API_PREFIX: &str = r#"/api/v1/task"#; @@ -127,15 +131,23 @@ pub async fn run( // Configure mTLS if client CA is provided let client_auth = if tls_client_ca_cert_path.is_some() { info!("Configuring TLS trust certificate (mTLS) for incoming connections"); - let client_certs = load_certs(tls_client_ca_cert_path.as_ref().unwrap()); + let 
client_certs = load_certs( + tls_client_ca_cert_path + .as_ref() + .expect("error loading certs for mTLS"), + ); let mut client_auth_certs = RootCertStore::empty(); for client_cert in client_certs { // Should be only one - client_auth_certs.add(client_cert).unwrap(); + client_auth_certs + .add(client_cert.clone()) + .unwrap_or_else(|e| { + panic!("error adding client cert {:?}: {}", client_cert, e) + }); } WebPkiClientVerifier::builder(client_auth_certs.into()) .build() - .unwrap() + .unwrap_or_else(|e| panic!("error building client verifier: {}", e)) } else { WebPkiClientVerifier::no_client_auth() }; @@ -150,7 +162,7 @@ pub async fn run( } // (2b) Add main guardrails server routes - let app = Router::new() + let mut router = Router::new() .route( &format!("{}/classification-with-text-generation", API_PREFIX), post(classification_with_gen), @@ -170,6 +182,10 @@ pub async fn run( &format!("{}/detection/content", TEXT_API_PREFIX), post(detection_content), ) + .route( + &format!("{}/detection/chat", TEXT_API_PREFIX), + post(detect_chat), + ) .route( &format!("{}/detection/context", TEXT_API_PREFIX), post(detect_context_documents), @@ -177,8 +193,24 @@ pub async fn run( .route( &format!("{}/detection/generated", TEXT_API_PREFIX), post(detect_generated), - ) - .with_state(shared_state); + ); + + // If chat generation is configured, enable the chat completions detection endpoint. + if shared_state.orchestrator.config().chat_generation.is_some() { + info!("Enabling chat completions detection endpoint"); + router = router.route( + "/api/v2/chat/completions-detection", + post(chat_completions_detection), + ); + } + + let app = router.with_state(shared_state).layer( + TraceLayer::new_for_http() + .make_span_with(tracing_utils::incoming_request_span) + .on_request(tracing_utils::on_incoming_request) + .on_response(tracing_utils::on_outgoing_response) + .on_eos(tracing_utils::on_outgoing_eos), + ); // (2c) Generate main guardrails server handle based on whether TLS is needed let listener: TcpListener = TcpListener::bind(&http_addr) @@ -294,29 +326,23 @@ async fn health() -> Result { async fn info( State(state): State>, - Query(params): Query, -) -> Result { - match state.orchestrator.clients_health(params.probe).await { - Ok(client_health_info) => Ok(client_health_info), - Err(error) => { - error!( - "Unexpected internal error while checking client health info: {:?}", - error - ); - Err(error.into()) - } - } + Query(params): Query, +) -> Result, Error> { + let services = state.orchestrator.client_health(params.probe).await; + Ok(Json(InfoResponse { services })) } +#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn classification_with_gen( State(state): State>, headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ClassificationWithGenTask::new(request_id, request, headers); + let task = ClassificationWithGenTask::new(trace_id, request, headers); match state .orchestrator .handle_classification_with_gen(task) @@ -327,6 +353,7 @@ async fn classification_with_gen( } } +#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn generation_with_detection( State(state): State>, headers: HeaderMap, @@ -335,10 +362,11 @@ async fn generation_with_detection( Error, >, ) -> 
Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = GenerationWithDetectionTask::new(request_id, request, headers); + let task = GenerationWithDetectionTask::new(trace_id, request, headers); match state .orchestrator .handle_generation_with_detection(task) @@ -349,12 +377,14 @@ async fn generation_with_detection( } } +#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn stream_classification_with_gen( State(state): State>, headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Sse>> { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); if let Err(error) = request.validate() { // Request validation failed, return stream with single error SSE event let error: Error = error.into(); @@ -367,7 +397,7 @@ async fn stream_classification_with_gen( ); } let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = StreamingClassificationWithGenTask::new(request_id, request, headers); + let task = StreamingClassificationWithGenTask::new(trace_id, request, headers); let response_stream = state .orchestrator .handle_streaming_classification_with_gen(task) @@ -391,30 +421,34 @@ async fn stream_classification_with_gen( Sse::new(event_stream).keep_alive(KeepAlive::default()) } +#[instrument(skip_all)] async fn detection_content( State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = TextContentDetectionTask::new(request_id, request, headers); + let task = TextContentDetectionTask::new(trace_id, request, headers); match state.orchestrator.handle_text_content_detection(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), } } +#[instrument(skip_all)] async fn detect_context_documents( State(state): State>, headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ContextDocsDetectionTask::new(request_id, request, headers); + let task = ContextDocsDetectionTask::new(trace_id, request, headers); match state .orchestrator .handle_context_documents_detection(task) @@ -425,6 +459,23 @@ async fn detect_context_documents( } } +#[instrument(skip_all)] +async fn detect_chat( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + let trace_id = Span::current().context().span().span_context().trace_id(); + request.validate_for_text()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = ChatDetectionTask::new(trace_id, request, headers); + match state.orchestrator.handle_chat_detection(task).await { + Ok(response) => Ok(Json(response).into_response()), + 
Err(error) => Err(error.into()), + } +} + +#[instrument(skip_all)] async fn detect_generated( State(state): State>, headers: HeaderMap, @@ -433,10 +484,11 @@ async fn detect_generated( Error, >, ) -> Result { - let request_id = Uuid::new_v4(); + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = DetectionOnGenerationTask::new(request_id, request, headers); + let task = DetectionOnGenerationTask::new(trace_id, request, headers); match state .orchestrator .handle_generated_text_detection(task) @@ -447,6 +499,33 @@ async fn detect_generated( } } +#[instrument(skip_all)] +async fn chat_completions_detection( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + let trace_id = Span::current().context().span().span_context().trace_id(); + info!(?trace_id, "handling request"); + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = ChatCompletionsDetectionTask::new(trace_id, request, headers); + match state + .orchestrator + .handle_chat_completions_detection(task) + .await + { + Ok(response) => match response { + ChatCompletionsResponse::Unary(response) => Ok(Json(response).into_response()), + ChatCompletionsResponse::Streaming(response_rx) => { + let response_stream = ReceiverStream::new(response_rx); + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); + Ok(sse.into_response()) + } + }, + Err(error) => Err(error.into()), + } +} + /// Shutdown signal handler async fn shutdown_signal() { let ctrl_c = async { diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs new file mode 100644 index 00000000..42da2169 --- /dev/null +++ b/src/tracing_utils.rs @@ -0,0 +1,373 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ + +use std::time::Duration; + +use axum::{extract::Request, http::HeaderMap, response::Response}; +use opentelemetry::{ + global, + metrics::MetricsError, + trace::{TraceContextExt, TraceError, TracerProvider}, + KeyValue, +}; +use opentelemetry_http::{HeaderExtractor, HeaderInjector}; +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::{ + metrics::{ + reader::{DefaultAggregationSelector, DefaultTemporalitySelector}, + SdkMeterProvider, + }, + propagation::TraceContextPropagator, + runtime, + trace::{Config, Sampler}, + Resource, +}; +use tracing::{error, info, info_span, Span}; +use tracing_opentelemetry::{MetricsLayer, OpenTelemetrySpanExt}; +use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer}; + +use crate::args::{LogFormat, OtlpProtocol, TracingConfig}; + +#[derive(Debug, thiserror::Error)] +pub enum TracingError { + #[error("Error from tracing provider: {0}")] + TraceError(#[from] TraceError), + #[error("Error from metrics provider: {0}")] + MetricsError(#[from] MetricsError), +} + +fn service_config(tracing_config: TracingConfig) -> Config { + Config::default() + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + tracing_config.service_name, + )])) + .with_sampler(Sampler::AlwaysOn) +} + +/// Initializes an OpenTelemetry tracer provider with an OTLP export pipeline based on the +/// provided config. +fn init_tracer_provider( + tracing_config: TracingConfig, +) -> Result, TracingError> { + if let Some((protocol, endpoint)) = tracing_config.clone().traces { + Ok(Some( + match protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::new_pipeline().tracing().with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(endpoint) + .with_timeout(Duration::from_secs(3)), + ), + OtlpProtocol::Http => opentelemetry_otlp::new_pipeline().tracing().with_exporter( + opentelemetry_otlp::new_exporter() + .http() + .with_http_client(reqwest::Client::new()) + .with_endpoint(endpoint) + .with_timeout(Duration::from_secs(3)), + ), + } + .with_trace_config(service_config(tracing_config)) + .install_batch(runtime::Tokio)?, + )) + } else if !tracing_config.quiet { + // We still need a tracing provider as long as we are logging in order to enable any + // trace-sensitive logs, such as any mentions of a request's trace_id. + Ok(Some( + opentelemetry_sdk::trace::TracerProvider::builder() + .with_config(service_config(tracing_config)) + .build(), + )) + } else { + Ok(None) + } +} + +/// Initializes an OpenTelemetry meter provider with an OTLP export pipeline based on the +/// provided config. 
+fn init_meter_provider( + tracing_config: TracingConfig, +) -> Result, TracingError> { + if let Some((protocol, endpoint)) = tracing_config.metrics { + Ok(Some( + match protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::new_pipeline() + .metrics(runtime::Tokio) + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_endpoint(endpoint), + ), + OtlpProtocol::Http => opentelemetry_otlp::new_pipeline() + .metrics(runtime::Tokio) + .with_exporter( + opentelemetry_otlp::new_exporter() + .http() + .with_http_client(reqwest::Client::new()) + .with_endpoint(endpoint), + ), + } + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + tracing_config.service_name, + )])) + .with_timeout(Duration::from_secs(10)) + .with_period(Duration::from_secs(3)) + .with_aggregation_selector(DefaultAggregationSelector::new()) + .with_temporality_selector(DefaultTemporalitySelector::new()) + .build()?, + )) + } else { + Ok(None) + } +} + +/// Initializes tracing for the orchestrator using the OpenTelemetry API/SDK and the `tracing` +/// crate. What telemetry is exported and to where is determined based on the provided config +pub fn init_tracing( + tracing_config: TracingConfig, +) -> Result Result<(), TracingError>, TracingError> { + let mut layers = Vec::new(); + global::set_text_map_propagator(TraceContextPropagator::new()); + + // TODO: Find a better way to only propagate errors from other crates + let filter = EnvFilter::try_from_default_env() + .unwrap_or(EnvFilter::new("INFO")) + .add_directive("ginepro=info".parse().unwrap()) + .add_directive("hyper=error".parse().unwrap()) + .add_directive("h2=error".parse().unwrap()) + .add_directive("trust_dns_resolver=error".parse().unwrap()) + .add_directive("trust_dns_proto=error".parse().unwrap()) + .add_directive("tower=error".parse().unwrap()) + .add_directive("tonic=error".parse().unwrap()) + .add_directive("reqwest=error".parse().unwrap()); + + // Set up tracing layer with OTLP exporter + let trace_provider = init_tracer_provider(tracing_config.clone())?; + if let Some(tracer_provider) = trace_provider.clone() { + global::set_tracer_provider(tracer_provider.clone()); + layers.push( + tracing_opentelemetry::layer() + .with_tracer(tracer_provider.tracer(tracing_config.service_name.clone())) + .boxed(), + ); + } + + // Set up metrics layer with OTLP exporter + let meter_provider = init_meter_provider(tracing_config.clone())?; + if let Some(meter_provider) = meter_provider.clone() { + global::set_meter_provider(meter_provider.clone()); + layers.push(MetricsLayer::new(meter_provider).boxed()); + } + + // Set up formatted layer for logging to stdout + // Because we use the `tracing` crate for logging, all logs are traces and will be exported + // to OTLP if `--otlp-export=traces` is set. 
+ if !tracing_config.quiet { + match tracing_config.log_format { + LogFormat::Full => layers.push(tracing_subscriber::fmt::layer().boxed()), + LogFormat::Compact => layers.push(tracing_subscriber::fmt::layer().compact().boxed()), + LogFormat::Pretty => layers.push(tracing_subscriber::fmt::layer().pretty().boxed()), + LogFormat::JSON => layers.push( + tracing_subscriber::fmt::layer() + .json() + .flatten_event(true) + .boxed(), + ), + } + } + + let subscriber = tracing_subscriber::registry().with(filter).with(layers); + tracing::subscriber::set_global_default(subscriber).unwrap(); + + if let Some(traces) = tracing_config.traces { + info!( + "OTLP tracing enabled: Exporting {} to {}", + traces.0, traces.1 + ); + } else { + info!("OTLP traces export disabled") + } + + if let Some(metrics) = tracing_config.metrics { + info!( + "OTLP metrics enabled: Exporting {} to {}", + metrics.0, metrics.1 + ); + } else { + info!("OTLP metrics export disabled") + } + + if !tracing_config.quiet { + info!( + "Stdout logging enabled with format {}", + tracing_config.log_format + ); + } else { + info!("Stdout logging disabled"); // This will only be visible in traces + } + + Ok(move || { + global::shutdown_tracer_provider(); + if let Some(meter_provider) = meter_provider { + meter_provider + .shutdown() + .map_err(TracingError::MetricsError)?; + } + Ok(()) + }) +} + +pub fn incoming_request_span(request: &Request) -> Span { + info_span!( + "incoming_orchestrator_http_request", + request_method = request.method().to_string(), + request_path = request.uri().path().to_string(), + response_status_code = tracing::field::Empty, + request_duration_ms = tracing::field::Empty, + stream_response = tracing::field::Empty, + stream_response_event_count = tracing::field::Empty, + stream_response_error_count = tracing::field::Empty, + stream_response_duration_ms = tracing::field::Empty, + ) +} + +pub fn on_incoming_request(request: &Request, span: &Span) { + let _guard = span.enter(); + info!( + "incoming request to {} {} with trace_id {}", + request.method(), + request.uri().path(), + span.context().span().span_context().trace_id().to_string() + ); + info!( + monotonic_counter.incoming_request_count = 1, + request_method = request.method().as_str(), + request_path = request.uri().path() + ); +} + +pub fn on_outgoing_response(response: &Response, latency: Duration, span: &Span) { + let _guard = span.enter(); + span.record("response_status_code", response.status().as_u16()); + span.record("request_duration_ms", latency.as_millis()); + + info!( + "response {} for request with with trace_id {} generated in {} ms", + &response.status(), + span.context().span().span_context().trace_id().to_string(), + latency.as_millis() + ); + + // On every response + info!( + monotonic_counter.handled_request_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + info!( + histogram.service_request_duration = latency.as_millis(), + response_status = response.status().as_u16() + ); + + if response.status().is_server_error() { + // On every server error (HTTP 5xx) response + info!( + monotonic_counter.server_error_response_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + } else if response.status().is_client_error() { + // On every client error (HTTP 4xx) response + info!( + monotonic_counter.client_error_response_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + } else if 
response.status().is_success() { + // On every successful (HTTP 2xx) response + info!( + monotonic_counter.success_response_count = 1, + response_status = response.status().as_u16(), + request_duration = latency.as_millis() + ); + } else { + error!( + "unexpected response status code: {}", + response.status().as_u16() + ); + } +} + +pub fn on_outgoing_eos(trailers: Option<&HeaderMap>, stream_duration: Duration, span: &Span) { + let _guard = span.enter(); + + span.record("stream_response", true); + span.record("stream_response_duration_ms", stream_duration.as_millis()); + + info!( + "stream response for request with trace_id {} closed after {} ms with trailers: {:?}", + span.context().span().span_context().trace_id().to_string(), + stream_duration.as_millis(), + trailers + ); + info!( + monotonic_counter.service_stream_response_count = 1, + stream_duration = stream_duration.as_millis() + ); + info!(monotonic_histogram.service_stream_response_duration = stream_duration.as_millis()); +} + +/// Injects the `traceparent` header into the header map from the current tracing span context. +/// Also injects empty `tracestate` header by default. This can be used to propagate +/// vendor-specific trace context. +/// Used by both gRPC and HTTP requests since `tonic::Metadata` uses `http::HeaderMap`. +/// See https://www.w3.org/TR/trace-context/#trace-context-http-headers-format. +pub fn with_traceparent_header(headers: HeaderMap) -> HeaderMap { + let mut headers = headers.clone(); + let ctx = Span::current().context(); + global::get_text_map_propagator(|propagator| { + // Injects current `traceparent` (and by default empty `tracestate`) + propagator.inject_context(&ctx, &mut HeaderInjector(&mut headers)) + }); + headers +} + +/// Extracts the `traceparent` header from an HTTP response's headers and uses it to set the current +/// tracing span context (i.e. use `traceparent` as parent to the current span). +/// Defaults to using the current context when no `traceparent` is found. +/// See https://www.w3.org/TR/trace-context/#trace-context-http-headers-format. +pub fn trace_context_from_http_response(response: &reqwest::Response) { + let ctx = global::get_text_map_propagator(|propagator| { + // Returns the current context if no `traceparent` is found + propagator.extract(&HeaderExtractor(response.headers())) + }); + Span::current().set_parent(ctx); +} + +/// Extracts the `traceparent` header from a gRPC response's metadata and uses it to set the current +/// tracing span context (i.e. use `traceparent` as parent to the current span). +/// Defaults to using the current context when no `traceparent` is found. +/// See https://www.w3.org/TR/trace-context/#trace-context-http-headers-format. +pub fn trace_context_from_grpc_response(response: &tonic::Response) { + let ctx = global::get_text_map_propagator(|propagator| { + let metadata = response.metadata().clone(); + // Returns the current context if no `traceparent` is found + propagator.extract(&HeaderExtractor(&metadata.into_headers())) + }); + Span::current().set_parent(ctx); +} diff --git a/tests/test.config.yaml b/tests/test.config.yaml index 6ca749db..91749356 100644 --- a/tests/test.config.yaml +++ b/tests/test.config.yaml @@ -11,8 +11,9 @@ chunkers: port: 8085 detectors: test_detector: + type: text_contents service: - hostname: https://localhost/api/v1/text/contents + hostname: localhost port: 8000 chunker_id: test_chunker default_threshold: 0.5