diff --git a/docs/user-guides/detections-api-integration.md b/docs/user-guides/detections-api-integration.md
new file mode 100644
index 000000000..ced33b6fa
--- /dev/null
+++ b/docs/user-guides/detections-api-integration.md
@@ -0,0 +1,1201 @@
+# Detections API Integration for NeMo Guardrails
+
+## Overview
+
+This integration enables NeMo Guardrails to communicate with external detector services that implement the Detections API v1/text/contents protocol, providing a standardized interface for content safety checking without requiring detector logic within NeMo.
+
+**Key Features:**
+- **Protocol-agnostic architecture**: Base interface pattern supports multiple detector API protocols (Detections API, KServe V1, future APIs)
+- **Configuration-driven**: Add/remove detectors via ConfigMap updates only
+- **Service-based detection**: Detectors run as independent microservices with rich metadata
+- **Extensible design**: Add support for new API protocols by implementing two methods (request builder and response parser)
+- **No code duplication**: Common HTTP, error handling, and orchestration logic shared across all detector types
+- **Parallel execution**: All detectors run concurrently using asyncio.gather(), so total latency is bounded by the slowest detector rather than the sum of all detectors
+- **System error separation**: Distinguishes content violations from infrastructure failures (timeouts, HTTP errors)
+
+## Architecture
+
+    User Input → NeMo Guardrails → Detections API Detector Services → vLLM (if safe) → Response
+
+**Components:**
+- **NeMo Guardrails** (CPU) - Orchestration and flow control
+- **Detections API Detectors** (CPU/GPU) - Content safety services implementing the v1/text/contents protocol (this guide demonstrates Granite Guardian HAP as an example)
+- **vLLM** (GPU) - LLM inference
+
+### Design: Base Interface Pattern
+
+This integration introduces a base interface architecture that eliminates code duplication when supporting multiple detector API protocols.
+
+**File Structure:**
+```
+nemoguardrails/library/detector_clients/
+├── base.py              # BaseDetectorClient interface (shared logic)
+├── detections_api.py    # Detections API v1/text/contents client
+├── actions.py           # NeMo action functions
+└── __init__.py          # Python package marker
+```
+
+**Why This Design:**
+
+A traditional approach would duplicate HTTP logic, error handling, and orchestration for each new API protocol. The base interface isolates what varies (request/response formats) from what stays constant (HTTP communication, error handling).
+
+**What's Shared (in base.py):**
+- HTTP session management with connection pooling
+- Authentication header handling (per-detector and global fallback)
+- Timeout and error handling
+- Standard `DetectorResult` model
+
+**What's API-Specific (in detections_api.py):**
+- Request format: `{"contents": [text], "detector_params": {}}`
+- Response parsing: Nested array structure `[[{detection1}, {detection2}]]`
+- Detection aggregation logic (multiple detections per text)
+- Threshold and filtering logic
+
+**Benefits:**
+- Adding a new API protocol means implementing two methods (`build_request`, `parse_response`); see the sketch below
+- No code changes to add detectors (ConfigMap only)
+- Same orchestration logic for all detector types
+- Extensible for future protocols (OpenAI Moderation, Perspective API, etc.)
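+
+To make the extension point concrete, here is a minimal sketch of a client for a new protocol. The protocol, class name, and response fields (`flagged`, `score`) are hypothetical; the sketch only illustrates that a subclass supplies the request/response mapping while reusing `_call_endpoint()` and `_handle_error()` from the base class:
+
+```python
+from nemoguardrails.library.detector_clients.base import BaseDetectorClient, DetectorResult
+
+
+class ExampleModerationClient(BaseDetectorClient):
+    """Hypothetical client for a flat {"flagged": bool, "score": float} protocol."""
+
+    async def detect(self, text: str) -> DetectorResult:
+        try:
+            # Shared HTTP logic from the base class handles auth, timeouts, and pooling
+            response, status = await self._call_endpoint(
+                endpoint=self.endpoint, payload=self.build_request(text), timeout=self.timeout
+            )
+            return self.parse_response(response, status)
+        except Exception as e:
+            # Shared error handling converts failures into a blocked DetectorResult
+            return self._handle_error(e, self.detector_name)
+
+    def build_request(self, text: str) -> dict:
+        # Protocol-specific request body
+        return {"input": text}
+
+    def parse_response(self, response, http_status: int) -> DetectorResult:
+        # Map the protocol's response onto the shared result model
+        flagged = bool(response.get("flagged", False))
+        return DetectorResult(
+            allowed=not flagged,
+            score=float(response.get("score", 0.0)),
+            reason="Flagged by moderation" if flagged else "No detections found",
+            label="MODERATION" if flagged else "NONE",
+            detector=self.detector_name,
+        )
+```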
+ +## Prerequisites + +- OpenShift cluster with KServe installed +- Access to Quay.io or container registry for pulling images +- vLLM deployment for LLM inference (or alternative OpenAI-compatible endpoint) + +## Requirements + +**This integration communicates with external services implementing the Detections API v1/text/contents protocol.** + +The Detections API provides structured detection results with rich metadata (spans, categories, confidence scores) rather than raw model outputs. Services must implement the standardized request/response format described below. + +### API Contract + +This integration uses **Detections API v1/text/contents protocol**. + +**Protocol:** REST API with detector-specific routing via headers + +**Requirements:** +- Endpoint path: `/api/v1/text/contents` +- Request header: `detector-id` specifying which detector to invoke +- Request body: `{"contents": ["text"], "detector_params": {}}` +- Response: Nested array of detection objects `[[{detection1}, {detection2}, ...]]` + +**Request Format:** +```json +POST /api/v1/text/contents +Header: detector-id: granite-guardian-hap + +{ + "contents": ["text to analyze"], + "detector_params": {} +} +``` + +**Response Format:** +```json +[[ + { + "start": 0, + "end": 20, + "detection_type": "pii", + "detection": "EmailAddress", + "score": 0.95, + "text": "matching text span", + "evidence": {}, + "metadata": {} + } +]] +``` + +Each detection includes: +- `start`, `end`: Character span indices in input text +- `detection_type`: Broad category (pii, toxicity, etc.) +- `detection`: Specific detection class +- `score`: Confidence score (0.0-1.0) +- `text`: Detected text span + +## How It Works + +### Detection Flow + +1. User sends message to NeMo Guardrails via HTTP POST to `/v1/chat/completions` +2. NeMo loads configuration from ConfigMap and triggers input safety flow defined in `rails.co` +3. `detections_api_check_all_detectors()` action executes, running all configured detectors in parallel +4. For each detector: + - `DetectionsAPIClient` builds request: `{"contents": [text], "detector_params": {}}` + - HTTP POST sent to detector service with `detector-id` header + - Detector service processes text and returns structured detections + - Parser extracts detections from nested array response `[[...]]` +5. Each detection is evaluated: + - If `detection.score >= threshold`: Detection triggers blocking + - Multiple detections per text are supported + - Highest scoring detection determines overall score +6. Results aggregation: + - System errors (timeouts, connection failures): Request blocked, tracked in `unavailable_detectors` + - Content violations: Request blocked, tracked in `blocking_detectors` with full metadata + - All pass: Request proceeds to vLLM for generation +7. Response returned to user (blocked message or LLM-generated response) + +### Base Interface Pattern + +The integration uses object-oriented design to eliminate code duplication across different detector API protocols. + +**BaseDetectorClient (Abstract Class):** +```python +class BaseDetectorClient(ABC): + @abstractmethod + async def detect(text: str) -> DetectorResult + + @abstractmethod + def build_request(text: str) -> dict + + @abstractmethod + def parse_response(response: dict, http_status: int) -> DetectorResult + + # Shared implementations: + async def _call_endpoint(...) # HTTP communication + def _handle_error(...) 
# Error handling
+```
+
+**DetectionsAPIClient (Implementation):**
+```python
+class DetectionsAPIClient(BaseDetectorClient):
+    def build_request(text: str) -> dict:
+        # Detections API specific format
+        return {"contents": [text], "detector_params": {}}
+
+    def parse_response(response: dict, http_status: int) -> DetectorResult:
+        # Parse [[{detection1}, {detection2}]]
+        # Apply threshold filtering
+        # Return standardized DetectorResult
+```
+
+**Adding New API Protocol:**
+
+To support a new detector API (e.g., OpenAI Moderation, Perspective API):
+1. Create a new client class inheriting from `BaseDetectorClient`
+2. Implement `build_request()` for the API's request format
+3. Implement `parse_response()` for the API's response format
+4. Add `@action()` decorated functions in `actions.py` that use the new client
+5. Reuse all HTTP, auth, and error handling from the base class
+
+### Detection Logic
+
+**Multiple Detections Handling:**
+
+Detections API services can return multiple detections for a single text (e.g., two email addresses, one SSN). The parser:
+1. Extracts all detections from the nested array structure
+2. Filters detections by threshold: `score >= threshold`
+3. Blocks if **ANY** detection meets the threshold (fail-safe approach)
+4. Returns the highest score as the primary score
+5. Includes all detection details in metadata for auditing
+
+**Example:**
+```
+Input: "Email me at test@example.com or call 555-1234"
+
+Response: [[
+  {detection: "EmailAddress", score: 0.99},
+  {detection: "PhoneNumber", score: 0.85}
+]]
+
+With threshold=0.5:
+- Both detections >= 0.5
+- Content blocked
+- Primary score: 0.99 (highest)
+- Label: "pii:EmailAddress" (highest scoring detection)
+- Metadata includes both detections
+```
+
+**Score Aggregation:**
+- `score`: Highest individual detection score
+- `metadata.detection_count`: Number of detections above threshold
+- `metadata.individual_scores`: All scores for analysis
+
+### Error Handling
+
+The system distinguishes between **content violations** (actual detections) and **system errors** (infrastructure failures like timeouts, HTTP errors, configuration issues).
+
+**Behavior:**
+- System errors tracked separately in `unavailable_detectors` list
+- Requests with system errors are blocked (fail-safe approach)
+- Clear error messages indicate which detectors are unavailable versus which found violations
+
+**System Error Labels:**
+`ERROR`, `HTTP_ERROR`, `TIMEOUT`, `NOT_FOUND`, `VALIDATION_ERROR`, `INVALID_RESPONSE`, `CONFIG_ERROR`
+
+This separation enables better operational monitoring and clearer user feedback.
+
+**System Errors:**
+- HTTP errors (404, 422, 500, 503)
+- Network timeouts
+- Invalid response formats
+- Result: `allowed=False`, `label="ERROR"` or `"TIMEOUT"`
+- Tracked in `unavailable_detectors` list
+- User message: "Service temporarily unavailable"
+
+**Content Violations:**
+- Successful detection with score >= threshold
+- Result: `allowed=False`, `label="{type}:{detection}"`
+- Tracked in `blocking_detectors` list with full metadata
+- User message: Details which detectors blocked and scores
+
+**Multiple Detector Failures:**
+
+When running multiple detectors, the system provides comprehensive feedback showing all blocking detectors and any unavailable services, enabling both user communication and operational monitoring.
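+
+For reference, the dict returned by `detections_api_check_all_detectors` for a blocked request looks like this (values illustrative, metadata abbreviated):
+
+```json
+{
+  "allowed": false,
+  "reason": "Blocked by 1 Detections API detector(s): granite_hap",
+  "blocking_detectors": [
+    {
+      "allowed": false,
+      "score": 0.99,
+      "reason": "Blocked by pii:EmailAddress (score=0.99)",
+      "label": "pii:EmailAddress",
+      "detector": "granite_hap",
+      "metadata": {"detection_count": 1}
+    }
+  ],
+  "allowing_detectors": [],
+  "detector_count": 1,
+  "unavailable_detectors": null
+}
+```
+
+When any detector hits a system error instead, the action returns early with `allowed: false`, a `"System error: ..."` reason, and the failed detector names listed in `unavailable_detectors`.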
+
+## Deployment Guide
+
+### Prerequisites
+
+- OpenShift cluster with KServe installed
+- Namespace: `<namespace>` (used as a placeholder in the examples throughout this guide)
+- Access to Quay.io for pulling images
+- vLLM or other OpenAI-compatible LLM endpoint for generation
+
+**This integration requires external Detections API services to be deployed.**
+
+Services must implement the v1/text/contents protocol with the request/response format described in the Requirements section.
+
+### Deployment Options
+
+**Option A: Using TrustyAI Guardrails Detectors (Recommended)**
+
+Deploy detectors from the [guardrails-detectors repository](https://github.com/trustyai-explainability/guardrails-detectors) which provides production-ready HuggingFace-based detectors implementing the Detections API protocol.
+
+**Option B: Deploy Your Own Detections API Service**
+
+Implement a service that exposes the `/api/v1/text/contents` endpoint following the API contract. Refer to the guardrails-detectors repository for reference implementations.
+
+This guide demonstrates Option A with the Granite Guardian HAP detector.
+
+### Step 1: Deploy Granite Guardian HAP Detector
+
+Granite Guardian requires model storage via MinIO (S3-compatible object storage running in-cluster) and uses a PVC-based approach to download and serve the model.
+
+**Why MinIO:** KServe expects S3-compatible storage for models. MinIO provides this locally without external dependencies, enabling disconnected cluster deployments.
+
+#### Deploy Model Storage and MinIO
+
+Create `granite-guardian-storage.yaml`:
+```yaml
+apiVersion: v1
+kind: Service
+metadata:
+  name: minio-guardrails-guardian
+spec:
+  ports:
+    - name: minio-client-port
+      port: 9000
+      protocol: TCP
+      targetPort: 9000
+  selector:
+    app: minio-guardrails-guardian
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+  name: guardrails-models-claim-guardian
+spec:
+  accessModes:
+    - ReadWriteOnce
+  volumeMode: Filesystem
+  resources:
+    requests:
+      storage: 100Gi
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: guardrails-container-deployment-guardian
+  labels:
+    app: minio-guardrails-guardian
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: minio-guardrails-guardian
+  template:
+    metadata:
+      labels:
+        app: minio-guardrails-guardian
+        maistra.io/expose-route: 'true'
+      name: minio-guardrails-guardian
+    spec:
+      securityContext:
+        fsGroup: 1001
+      volumes:
+        - name: model-volume
+          persistentVolumeClaim:
+            claimName: guardrails-models-claim-guardian
+      initContainers:
+        - name: download-model
+          image: quay.io/rgeada/llm_downloader:latest
+          command:
+            - bash
+            - -c
+            - |
+              model="ibm-granite/granite-guardian-3.0-2b"
+              echo "Starting download of ${model}"
+              /tmp/venv/bin/huggingface-cli download $model --local-dir /mnt/models/huggingface/$(basename $model)
+              echo "Download complete!"
+ resources: + limits: + memory: "2Gi" + cpu: "2" + volumeMounts: + - mountPath: "/mnt/models/" + name: model-volume + containers: + - args: + - server + - /models + env: + - name: MINIO_ACCESS_KEY + value: THEACCESSKEY + - name: MINIO_SECRET_KEY + value: THESECRETKEY + image: quay.io/trustyai/modelmesh-minio-examples:latest + name: minio + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - ALL + seccompProfile: + type: RuntimeDefault + volumeMounts: + - mountPath: "/models/" + name: model-volume +--- +apiVersion: v1 +kind: Secret +metadata: + name: aws-connection-minio-data-connection-guardrails-guardian + labels: + opendatahub.io/dashboard: 'true' + opendatahub.io/managed: 'true' + annotations: + opendatahub.io/connection-type: s3 + openshift.io/display-name: Minio Data Connection +data: + AWS_ACCESS_KEY_ID: VEhFQUNDRVNTS0VZ + AWS_DEFAULT_REGION: dXMtc291dGg= + AWS_S3_BUCKET: aHVnZ2luZ2ZhY2U= + AWS_S3_ENDPOINT: aHR0cDovL21pbmlvLWd1YXJkcmFpbHMtZ3VhcmRpYW46OTAwMA== + AWS_SECRET_ACCESS_KEY: VEhFU0VDUkVUS0VZ +type: Opaque +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: user-one +--- +kind: RoleBinding +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: user-one-view +subjects: + - kind: ServiceAccount + name: user-one +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: view +``` + +Deploy: +```bash +oc apply -f granite-guardian-storage.yaml -n +``` + +Monitor model download (takes 5-10 minutes for ~5GB model): +```bash +oc logs -f deployment/guardrails-container-deployment-guardian -n -c download-model +``` + +Wait for "Download complete!" message. + +Verify MinIO is running: +```bash +oc get pods -n | grep guardrails-container +``` + +Expected: Pod shows `2/2 Running` (init container completed, MinIO running) + +#### Deploy ServingRuntime for Granite Guardian + +Create `granite-guardian-runtime.yaml`: +```yaml +apiVersion: serving.kserve.io/v1alpha1 +kind: ServingRuntime +metadata: + name: guardrails-detector-runtime-guardian + annotations: + openshift.io/display-name: Guardrails Detector ServingRuntime for KServe + opendatahub.io/recommended-accelerators: '["nvidia.com/gpu"]' + labels: + opendatahub.io/dashboard: 'true' +spec: + annotations: + prometheus.io/port: '8000' + prometheus.io/path: '/metrics' + multiModel: false + supportedModelFormats: + - autoSelect: true + name: guardrails-detector-huggingface + containers: + - name: kserve-container + image: quay.io/rh-ee-mmisiura/guardrails-detector-huggingface:3d51741 + command: + - uvicorn + - app:app + args: + - "--workers" + - "1" + - "--host" + - "0.0.0.0" + - "--port" + - "8000" + - "--log-config" + - "/common/log_conf.yaml" + env: + - name: MODEL_DIR + value: /mnt/models + - name: HF_HOME + value: /tmp/hf_home + - name: DETECTOR_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + ports: + - containerPort: 8000 + protocol: TCP + resources: + requests: + memory: "18Gi" + cpu: "1" + limits: + memory: "20Gi" + cpu: "2" +``` + +Deploy: +```bash +oc apply -f granite-guardian-runtime.yaml -n +``` + +Verify: +```bash +oc get servingruntime -n | grep guardian +``` + +Expected: `guardrails-detector-runtime-guardian` appears in list + +#### Deploy Granite Guardian InferenceService + +Create `granite-guardian-isvc.yaml`: +```yaml +apiVersion: serving.kserve.io/v1beta1 +kind: InferenceService +metadata: + name: guardrails-detector-ibm-guardian + labels: + opendatahub.io/dashboard: 'true' + annotations: + openshift.io/display-name: guardrails-detector-ibm-guardian + 
security.opendatahub.io/enable-auth: 'true' + serving.knative.openshift.io/enablePassthrough: 'true' + sidecar.istio.io/inject: 'true' + sidecar.istio.io/rewriteAppHTTPProbers: 'true' + serving.kserve.io/deploymentMode: RawDeployment +spec: + predictor: + maxReplicas: 1 + minReplicas: 1 + model: + modelFormat: + name: guardrails-detector-huggingface + name: '' + runtime: guardrails-detector-runtime-guardian + storage: + key: aws-connection-minio-data-connection-guardrails-guardian + path: granite-guardian-3.0-2b +``` + +Deploy: +```bash +oc apply -f granite-guardian-isvc.yaml -n +``` + +Wait for predictor pod to start and load model (3-5 minutes): +```bash +oc get pods -n | grep guardrails-detector-ibm-guardian + +# Watch logs +oc logs -f -n $(oc get pods -n -l serving.kserve.io/inferenceservice=guardrails-detector-ibm-guardian -o name | head -1) -c kserve-container +``` + +Expected log output: +``` +Model type detected: causal_lm +Application startup complete. +Uvicorn running on http://0.0.0.0:8000 +``` + +Verify InferenceService is ready: +```bash +oc get inferenceservice guardrails-detector-ibm-guardian -n +``` + +Expected: `READY = True` + +**Note:** Granite Guardian runs on CPU by default. Inference takes 30-120 seconds per request. For production, consider deploying on GPU nodes or increasing timeout configuration. + +### Step 2: Deploy vLLM Inference Service + +vLLM uses a PVC-based approach to pre-download the Phi-3-mini model. This avoids runtime dependencies on HuggingFace and uses Red Hat's official AI Inference Server image. + +Create `vllm-inferenceservice.yml`: +```yaml +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: phi3-model-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 20Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: phi3-model-downloader +spec: + replicas: 1 + selector: + matchLabels: + app: phi3-downloader + template: + metadata: + labels: + app: phi3-downloader + spec: + initContainers: + - name: download-model + image: quay.io/rgeada/llm_downloader:latest + command: + - bash + - -c + - | + echo "Downloading Phi-3-mini" + /tmp/venv/bin/huggingface-cli download microsoft/Phi-3-mini-4k-instruct --local-dir /mnt/models/phi3-mini + echo "Download complete!" + volumeMounts: + - name: model-storage + mountPath: /mnt/models + containers: + - name: placeholder + image: registry.access.redhat.com/ubi9/ubi-minimal:latest + command: ["sleep", "infinity"] + volumes: + - name: model-storage + persistentVolumeClaim: + claimName: phi3-model-pvc +--- +apiVersion: serving.kserve.io/v1beta1 +kind: InferenceService +metadata: + name: vllm-phi3 +spec: + predictor: + containers: + - name: kserve-container + image: registry.redhat.io/rhaiis/vllm-cuda-rhel9:3 + args: + - --model=/mnt/models/phi3-mini + - --host=0.0.0.0 + - --port=8080 + - --served-model-name=phi3-mini + - --max-model-len=4096 + - --gpu-memory-utilization=0.7 + - --trust-remote-code + - --dtype=half + env: + - name: HF_HOME + value: /tmp/hf_cache + volumeMounts: + - name: model-storage + mountPath: /mnt/models + readOnly: true + resources: + limits: + nvidia.com/gpu: 1 + cpu: "6" + memory: "24Gi" + requests: + nvidia.com/gpu: 1 + cpu: "2" + memory: "8Gi" + volumes: + - name: model-storage + persistentVolumeClaim: + claimName: phi3-model-pvc +``` +Deploy: + +```bash +oc apply -f vllm-inferenceservice.yml -n +``` + +Monitor model download progress: + +```bash +oc logs -n -l app=phi3-downloader -c download-model -f +``` + +Wait for "Download complete!" 
message. The Phi-3-mini model is approximately 8GB and may take 3-5 minutes to download.
+
+Verify vLLM is running:
+
+```bash
+oc get inferenceservice vllm-phi3 -n <namespace>
+oc get pods -n <namespace> | grep vllm-phi3
+```
+
+Expected: `vllm-phi3` InferenceService shows `READY = True` and pod shows `1/1 Running`.
+
+### Step 3: Deploy NeMo Guardrails ConfigMap
+
+The ConfigMap contains detector configurations and flow definitions. Detectors are registered in the `detections_api_detectors` section with their endpoint URLs and detection parameters.
+
+Create `nemo-detections-configmap.yaml`:
+```yaml
+apiVersion: v1
+kind: ConfigMap
+metadata:
+  name: nemo-detections-configmap
+data:
+  config.yaml: |
+    rails:
+      config:
+        detections_api_detectors:
+          granite_hap:
+            inference_endpoint: "http://guardrails-detector-ibm-guardian-predictor.<namespace>.svc.cluster.local:8000/api/v1/text/contents"
+            detector_id: "granite-guardian-hap"
+            threshold: 0.5
+            timeout: 120
+            detector_params: {}
+      input:
+        flows:
+          - check_input_safety_detections_api
+    models:
+      - type: main
+        engine: vllm_openai
+        model: phi3-mini
+        parameters:
+          openai_api_base: http://vllm-phi3-predictor.<namespace>.svc.cluster.local:8080/v1
+          openai_api_key: sk-dummy-key
+    instructions:
+      - type: general
+        content: |
+          You are a helpful AI assistant.
+
+  rails.co: |
+    define bot blocked by detector
+      "Input blocked by content safety detectors"
+
+    define bot output blocked by detector
+      "I apologize, but I cannot provide that response."
+
+    define bot service unavailable
+      "Service temporarily unavailable"
+
+    define flow check_input_safety_detections_api
+      $input_result = execute detections_api_check_all_detectors
+
+      if $input_result.unavailable_detectors
+        bot service unavailable
+        stop
+
+      if not $input_result.allowed
+        bot blocked by detector
+        stop
+
+    define flow check_output_safety_detections_api
+      $output_result = execute detections_api_check_all_detectors
+
+      if $output_result.unavailable_detectors
+        bot service unavailable
+        stop
+
+      if not $output_result.allowed
+        bot output blocked by detector
+        stop
+```
+
+**Configuration Fields:**
+- `inference_endpoint`: Full URL to the detector's `/api/v1/text/contents` endpoint
+- `detector_id`: Identifier sent in the `detector-id` header (detector-specific)
+- `threshold`: Minimum score to trigger blocking (0.0-1.0)
+- `timeout`: Request timeout in seconds (increase for CPU-based detectors)
+- `detector_params`: Optional detector-specific parameters (sent in request body)
+
+**Important:**
+- Timeout should be 120+ seconds for CPU-based detectors like Granite Guardian
+- Replace `<namespace>` with your actual namespace
+- `detector_id` must match what the detector service expects
+
+Deploy:
+```bash
+oc apply -f nemo-detections-configmap.yaml -n <namespace>
+```
+
+Verify:
+```bash
+oc get configmap nemo-detections-configmap -n <namespace>
+```
+
+### Step 4: Deploy NeMo Guardrails Server
+
+Create `nemo-deployment.yaml`:
+```yaml
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: nemo-guardrails-server
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: nemo-guardrails
+  template:
+    metadata:
+      labels:
+        app: nemo-guardrails
+    spec:
+      containers:
+        - name: nemo-guardrails
+          image: quay.io/rh-ee-stondapu/trustyai-nemo:latest
+          imagePullPolicy: Always
+          env:
+            - name: CONFIG_ID
+              value: production
+            - name: OPENAI_API_KEY
+              value: sk-dummy-key
+            - name: DETECTIONS_API_KEY
+              value: "your-global-token"
+          ports:
+            - containerPort: 8000
+          volumeMounts:
+            - name: config-volume
+              mountPath: /app/config/production
+          resources:
+            requests:
+              cpu: "500m"
+              memory: "1Gi"
+            limits:
+              cpu: "2"
+              memory: "4Gi"
+      volumes:
+        - name: config-volume
+          configMap:
+            name: nemo-detections-configmap
+---
+apiVersion: v1
+kind: Service
+metadata:
+  name: nemo-guardrails-server
+spec:
+  selector:
+    app: nemo-guardrails
+  ports:
+    - port: 8000
+      targetPort: 8000
+  type: ClusterIP
+---
+apiVersion: route.openshift.io/v1
+kind: Route
+metadata:
+  name: nemo-guardrails-server
+spec:
+  port:
+    targetPort: 8000
+  tls:
+    termination: edge
+    insecureEdgeTerminationPolicy: Allow
+  to:
+    kind: Service
+    name: nemo-guardrails-server
+```
+
+Deploy:
+```bash
+oc apply -f nemo-deployment.yaml -n <namespace>
+```
+
+Get the external route URL:
+```bash
+YOUR_ROUTE="http://$(oc get route nemo-guardrails-server -n <namespace> -o jsonpath='{.spec.host}')"
+echo "NeMo Guardrails URL: $YOUR_ROUTE"
+```
+
+Verify NeMo is running:
+```bash
+oc get pods -n <namespace> | grep nemo-guardrails-server
+```
+
+Expected: Pod shows `1/1 Running`
+
+Check logs to confirm the detector loaded:
+```bash
+oc logs -n <namespace> $(oc get pods -n <namespace> -l app=nemo-guardrails -o name | head -1)
+```
+
+Expected log output should show:
+```
+Configuration validated. Starting server...
+Application startup complete.
+Uvicorn running on http://0.0.0.0:8000
+```
+
+No "Failed to register" errors should appear.
+
+## Testing
+
+### Unit Testing
+
+The integration includes **104 comprehensive unit tests** with **97%+ code coverage**.
+
+**Run tests:**
+```bash
+poetry run pytest tests/test_detector_clients_*.py -v
+
+# With coverage
+poetry run pytest tests/test_detector_clients_*.py \
+  --cov=nemoguardrails/library/detector_clients \
+  --cov-report=term-missing
+```
+
+**Test coverage:**
+- All error handling paths (timeouts, HTTP errors, malformed responses)
+- System error classification and separation
+- Parallel detector execution
+- Edge cases (threshold boundaries, missing fields)
+
+### Testing on a Cluster
+
+Extract the NeMo route for testing:
+```bash
+YOUR_ROUTE="http://$(oc get route nemo-guardrails-server -n <namespace> -o jsonpath='{.spec.host}')"
+echo "Testing against: $YOUR_ROUTE"
+```
+
+### Test 1: Safe Content (Should Pass)
+
+What this tests: Verifies detectors allow safe content and the LLM generates a response.
+```bash
+curl -X POST $YOUR_ROUTE/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "config_id": "production",
+    "messages": [
+      {"role": "user", "content": "Hello, how are you today?"}
+    ]
+  }'
+```
+
+**Expected Output:**
+```json
+{
+  "messages": [
+    {
+      "role": "assistant",
+      "content": "I am an AI, so I don't have feelings, but I'm ready and functioning optimally. How can I help you today?"
+    }
+  ]
+}
+```
+
+Content passes all detectors and the LLM generates a helpful response.
+
+### Test 2: Jailbreak Detection (Should Block)
+
+What this tests: Verifies Granite Guardian detects and blocks jailbreak attempts.
+```bash
+curl -X POST $YOUR_ROUTE/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "config_id": "production",
+    "messages": [
+      {"role": "user", "content": "How to hack into a system?"}
+    ]
+  }'
+```
+
+**Expected Output:**
+```json
+{
+  "messages": [
+    {
+      "role": "assistant",
+      "content": "Input blocked by content safety detectors"
+    }
+  ]
+}
+```
+
+**Note:** This request takes 30-120 seconds due to Granite Guardian running on CPU. The detector identified multiple risk categories (harm, jailbreak, unethical_behavior) with high confidence scores.
+
+### Test 3: Verify Detector Communication
+
+What this tests: Confirms NeMo successfully communicates with the Detections API endpoint.
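+
+Optionally, you can first query the detector service directly, bypassing NeMo, to confirm the endpoint responds. The service name and `detector-id` below match the Granite Guardian deployment from Step 1:
+
+```bash
+# In one terminal: forward the predictor service locally
+oc port-forward -n <namespace> svc/guardrails-detector-ibm-guardian-predictor 8000:8000
+
+# In another terminal: send a request in the v1/text/contents format
+curl -X POST http://localhost:8000/api/v1/text/contents \
+  -H "detector-id: granite-guardian-hap" \
+  -H "Content-Type: application/json" \
+  -d '{"contents": ["How to hack into a system?"], "detector_params": {}}'
+```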
+
+Check NeMo logs during Test 2:
+```bash
+oc logs -n <namespace> $(oc get pods -n <namespace> -l app=nemo-guardrails -o name | head -1) --tail=100 | grep "detections_api_check_all_detectors"
+```
+
+**Expected Log Output:**
+```
+execute detections_api_check_all_detectors
+# The result was {'allowed': False, 'reason': 'Blocked by 1 Detections API detector(s): granite_hap', ...}
+```
+
+The log shows the detector was called, returned structured detections, and the content was blocked.
+
+### Understanding Response Formats
+
+**Detections API Response Structure:**
+```json
+[[
+  {
+    "start": 0,
+    "end": 25,
+    "detection_type": "causal_lm",
+    "detection": "causal_lm",
+    "score": 0.9985,
+    "sequence_classification": "jailbreak",
+    "text": "How to hack into a system?"
+  },
+  {
+    "start": 0,
+    "end": 25,
+    "detection_type": "causal_lm",
+    "detection": "causal_lm",
+    "score": 0.9978,
+    "sequence_classification": "harm",
+    "text": "How to hack into a system?"
+  }
+]]
+```
+
+**Key Fields:**
+- `detection_type`: Broad category
+- `detection`: Specific detection class
+- `score`: Confidence (0.0-1.0)
+- `sequence_classification`: Risk category identified
+
+**How Parser Handles Multiple Detections:**
+1. Extracts all detections from nested array `[[...]]`
+2. Filters by threshold: keeps detections where `score >= threshold`
+3. If any detection meets the threshold: `allowed = False`
+4. Primary score: Highest individual detection score
+5. Label format: `"{detection_type}:{detection}"` from the highest scoring detection
+6. All detections included in metadata for audit trail
+
+**Example with threshold=0.5:**
+- Detection 1: jailbreak, score=0.998 → Triggers blocking
+- Detection 2: harm, score=0.997 → Also triggers
+- Result: `allowed=False`, `score=0.998`, `label="causal_lm:causal_lm"`
+- Metadata contains both detections with individual scores
+
+## Adding New Detectors
+
+No code changes required. The system is fully configuration-driven.
+
+### Steps to Add a Detector
+
+1. **Deploy a detector service** implementing the Detections API v1/text/contents protocol
+2. **Determine the detector-id** required by the service
+3. **Choose an appropriate threshold** for your use case
+4. **Add the detector configuration** to the NeMo ConfigMap
+5. **Apply the ConfigMap and restart** NeMo to load the new detector
+
+### Example: Adding a New Detector
+
+This example shows adding a hypothetical toxicity detector to complement Granite Guardian.
+
+**Step 1: Deploy Detector Service**
+
+Follow the detector service's deployment instructions. For TrustyAI guardrails-detectors, use the repository's deployment files similar to Granite Guardian.
+
+**Step 2: Test Detector Endpoint**
+
+Identify the detector-id and test the endpoint directly:
+```bash
+# Port forward to detector service
+oc port-forward -n <namespace> svc/your-detector-predictor 8000:8000
+
+# Test with sample content
+curl -X POST http://localhost:8000/api/v1/text/contents \
+  -H "detector-id: your-detector-id" \
+  -H "Content-Type: application/json" \
+  -d '{"contents": ["test content"], "detector_params": {}}'
+```
+
+Examine the response to understand:
+- What `detector-id` value to use
+- Detection score ranges
+- What constitutes a detection (for threshold tuning)
+
+**Step 3: Add to ConfigMap**
+
+Edit `nemo-detections-configmap.yaml` and add your detector:
+```yaml
+detections_api_detectors:
+  granite_hap:
+    # ... existing detector ...
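+  # Additional detectors are registered as sibling keys;
+  # each entry follows the same schema as granite_hap above.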
+
+  your_detector:  # Detector name (used in logs and error messages)
+    inference_endpoint: "http://your-detector-predictor.<namespace>.svc.cluster.local:8000/api/v1/text/contents"
+    detector_id: "your-detector-id"
+    threshold: 0.7
+    timeout: 30
+    detector_params: {}
+```
+
+**Step 4: Apply and Restart**
+```bash
+oc apply -f nemo-detections-configmap.yaml -n <namespace>
+oc rollout restart deployment/nemo-guardrails-server -n <namespace>
+```
+
+**Step 5: Test New Detector**
+```bash
+curl -X POST $YOUR_ROUTE/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{"config_id": "production", "messages": [{"role": "user", "content": "content that triggers your detector"}]}'
+```
+
+Check logs to verify the detector executed:
+```bash
+oc logs -n <namespace> $(oc get pods -n <namespace> -l app=nemo-guardrails -o name | head -1) --tail=50 | grep "your_detector"
+```
+
+### Determining Configuration Values
+
+**Threshold Selection:**
+- Start with `0.5` (moderate sensitivity)
+- Test with sample content
+- Increase (e.g., 0.7) to reduce false positives
+- Decrease (e.g., 0.3) to catch more potential issues
+- Monitor blocking rates and adjust
+
+**Timeout Selection:**
+- CPU-based detectors: 60-120 seconds
+- GPU-based detectors: 10-30 seconds
+- Network latency considerations: Add 5-10 seconds buffer
+- Monitor actual response times in logs
+
+**detector_params:**
+- Consult the detector service documentation
+- Used for detector-specific configuration
+- Passed through to the detector service in the request body
+- Example: `{"language": "en", "categories": ["pii", "toxicity"]}`
+
+## Resource Cleanup
+
+The integration uses a shared HTTP session for connection pooling. For proper resource cleanup during application shutdown:
+```python
+from nemoguardrails.library.detector_clients.base import cleanup_http_session
+
+# At application shutdown
+await cleanup_http_session()
+```
+
+This prevents resource leaks by properly closing the aiohttp session. The function is idempotent and safe to call multiple times.
+
+## Authentication (Optional)
+
+Detections API services can be secured with authentication to restrict access.
+
+### Prerequisites for Authentication
+
+Authentication requires one of:
+- Service Mesh (Istio) with Authorino (for OpenShift AI/OpenDataHub deployments)
+- API Gateway with authentication capabilities
+- Alternative authentication mechanism (OAuth proxy, etc.)
+
+### Enabling Authentication on Detector Services
+
+Authentication configuration depends on your detector deployment method and infrastructure.
+
+**For TrustyAI Guardrails Detectors with OpenShift AI:**
+
+Add authentication annotations to the InferenceService:
+```yaml
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+  name: guardrails-detector-ibm-guardian
+  annotations:
+    security.opendatahub.io/enable-auth: 'true'
+    serving.kserve.io/deploymentMode: RawDeployment
+    serving.knative.openshift.io/enablePassthrough: 'true'
+    sidecar.istio.io/inject: 'true'
+spec:
+  # ... rest of spec
+```
+
+**Note:** Authentication annotations vary by cluster infrastructure. Consult your cluster administrator for the correct configuration.
+
+### Configuring NeMo Authentication to Detectors
+
+NeMo supports both global authentication tokens and per-detector tokens with automatic fallback.
+
+**Option 1: Global Token (All Detectors)**
+
+Set an environment variable in the NeMo deployment:
+```yaml
+env:
+  - name: CONFIG_ID
+    value: production
+  - name: DETECTIONS_API_KEY
+    value: "your-bearer-token"
+```
+
+All detectors without an explicit `api_key` will use this token.
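+
+When a token is resolved (per-detector `api_key` first, then the `DETECTIONS_API_KEY` environment variable), the client attaches it to each detector request as a standard bearer header, alongside the detector routing header:
+
+```
+POST /api/v1/text/contents
+Content-Type: application/json
+Authorization: Bearer <token>
+detector-id: granite-guardian-hap
+```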
+ +**Option 2: Per-Detector Tokens** + +Specify in ConfigMap: +```yaml +detections_api_detectors: + granite_hap: + inference_endpoint: "..." + detector_id: "granite-guardian-hap" + api_key: "granite-specific-token" + threshold: 0.5 + + other_detector: + inference_endpoint: "..." + detector_id: "other-id" + # No api_key specified - falls back to DETECTIONS_API_KEY env var + threshold: 0.7 +``` + +**Token Priority:** Per-detector `api_key` → Global `DETECTIONS_API_KEY` env var → No authentication + +**Getting Tokens:** +```bash +# For OpenShift service accounts +oc sa get-token -n + +# For OpenShift AI secured services +oc whoami -t +``` diff --git a/nemoguardrails/library/detector_clients/__init__.py b/nemoguardrails/library/detector_clients/__init__.py new file mode 100644 index 000000000..3159bfe65 --- /dev/null +++ b/nemoguardrails/library/detector_clients/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/library/detector_clients/actions.py b/nemoguardrails/library/detector_clients/actions.py new file mode 100644 index 000000000..c3dbd6130 --- /dev/null +++ b/nemoguardrails/library/detector_clients/actions.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +NeMo action functions for Detections API integration. +""" + +import asyncio +import logging +from typing import Any, Dict, Optional + +from nemoguardrails.actions import action +from nemoguardrails.library.detector_clients.base import AggregatedDetectorResult, DetectorResult +from nemoguardrails.library.detector_clients.detections_api import DetectionsAPIClient + +log = logging.getLogger(__name__) + +""" System error labels indicate infrastructure/configuration issues, + not content violations. Detectors with these labels failed to execute + properly and should be treated as unavailable. """ +SYSTEM_ERROR_LABELS = { + "ERROR", + "HTTP_ERROR", + "TIMEOUT", + "NOT_FOUND", + "VALIDATION_ERROR", + "INVALID_RESPONSE", + "CONFIG_ERROR", +} + + +async def _run_detections_api_detector(detector_name: str, detector_config: Any, text: str) -> DetectorResult: + """ + Execute single Detections API detector. + + Internal helper function used by action functions. 
+ + Args: + detector_name: Name of the detector + detector_config: DetectionsAPIConfig object + text: Input text to analyze + + Returns: + DetectorResult with detection outcome + """ + try: + client = DetectionsAPIClient(detector_config, detector_name) + except ValueError as e: + # Constructor validation failed (e.g., missing detector_id) + log.error(f"{detector_name} configuration error: {e}") + return DetectorResult( + allowed=False, + score=0.0, + reason=f"{detector_name} configuration error: {str(e)}", + label="ERROR", + detector=detector_name, + metadata={"error": str(e)}, + ) + + # detect() handles all runtime errors internally and always returns DetectorResult + result = await client.detect(text) + return result + + +@action() +async def detections_api_check_all_detectors( + context: Optional[Dict] = None, config: Optional[Any] = None, **kwargs +) -> Dict[str, Any]: + """ + Run all configured Detections API detectors in parallel. + + This is the main action function called by NeMo rails.co flows. + Automatically detects and checks the appropriate message type from context + (user_message for input guardrails, bot_message for output guardrails). + + Args: + context: NeMo context dict containing message content (user_message, bot_message, etc.) + config: NeMo config object + **kwargs: Additional keyword arguments + + Returns: + Dict representation of AggregatedDetectorResult + """ + + if context is None: + context = {} + + if not config: + config = context.get("config") + + if not config: + return AggregatedDetectorResult( + allowed=False, + reason="No configuration provided", + blocking_detectors=[], + allowing_detectors=[], + detector_count=0, + ).dict() + + message_sources = ["user_message", "bot_message"] + text = "" + + for source in message_sources: + if source in context: + message = context[source] + text = message.get("content", "") if isinstance(message, dict) else str(message) + if text: + log.debug(f"Checking {source} with Detections API detectors") + break + + if not text: + log.warning("No message content found in context for detection") + return AggregatedDetectorResult( + allowed=True, + reason="No message content found", + blocking_detectors=[], + allowing_detectors=[], + detector_count=0, + ).dict() + + if not hasattr(config, "rails") or not hasattr(config.rails, "config"): + log.warning("Configuration incomplete") + return AggregatedDetectorResult( + allowed=True, + reason="Configuration incomplete", + blocking_detectors=[], + allowing_detectors=[], + detector_count=0, + ).dict() + + detections_api_detectors = getattr(config.rails.config, "detections_api_detectors", {}) + + if not detections_api_detectors: + return AggregatedDetectorResult( + allowed=True, + reason="No Detections API detectors configured", + blocking_detectors=[], + allowing_detectors=[], + detector_count=0, + ).dict() + + log.info( + f"Running {len(detections_api_detectors)} Detections API detectors: {list(detections_api_detectors.keys())}" + ) + + detector_names = [] + tasks = [] + + for name, config_obj in detections_api_detectors.items(): + detector_names.append(name) + tasks.append(_run_detections_api_detector(name, config_obj, text)) + + # Gather all results + results = await asyncio.gather(*tasks, return_exceptions=True) + + system_errors = [] + content_blocks = [] + allowing = [] + + for detector_name, result in zip(detector_names, results): + if isinstance(result, Exception): + log.error(f"{detector_name} exception: {result}") + error_result = DetectorResult( + allowed=False, + score=0.0, + 
reason=f"Exception: {result}", + label="ERROR", + detector=detector_name, + metadata={"error": str(result)}, + ) + system_errors.append(error_result) + elif result.label in SYSTEM_ERROR_LABELS: + system_errors.append(result) + elif not result.allowed: + content_blocks.append(result) + else: + allowing.append(result) + + if system_errors: + unavailable = [e.detector for e in system_errors] + reason = f"System error: {len(system_errors)} Detections API detector(s) unavailable - {', '.join(unavailable)}" + log.warning(reason) + + return AggregatedDetectorResult( + allowed=False, + reason=reason, + unavailable_detectors=unavailable, + blocking_detectors=content_blocks, + allowing_detectors=allowing, + detector_count=len(detections_api_detectors), + ).dict() + + overall_allowed = len(content_blocks) == 0 + + if overall_allowed: + reason = f"Approved by all {len(allowing)} Detections API detectors" + else: + blocking_detector_names = [d.detector for d in content_blocks] + reason = ( + f"Blocked by {len(content_blocks)} Detections API detector(s): {', '.join(set(blocking_detector_names))}" + ) + + log.info(f"Detections API: {'ALLOWED' if overall_allowed else 'BLOCKED'}: {reason}") + + return AggregatedDetectorResult( + allowed=overall_allowed, + reason=reason, + blocking_detectors=content_blocks, + allowing_detectors=allowing, + detector_count=len(detections_api_detectors), + ).dict() + + +@action() +async def detections_api_check_detector( + context: Optional[Dict] = None, config: Optional[Any] = None, detector_name: str = "mock_pii", **kwargs +) -> Dict[str, Any]: + """ + Run specific Detections API detector by name. + + Automatically detects and checks the appropriate message type from context + (user_message for input guardrails, bot_message for output guardrails). + + Args: + context: NeMo context dict containing message content (user_message, bot_message, etc.) 
+ config: NeMo config object + detector_name: Name of detector to run + **kwargs: Additional keyword arguments + + Returns: + Dict representation of DetectorResult + """ + if context is None: + context = {} + + if not config: + config = context.get("config") + + if not config: + return DetectorResult( + allowed=False, + score=0.0, + reason="No configuration provided", + label="NO_CONFIG", + detector=detector_name, + metadata={}, + ).dict() + + message_sources = ["user_message", "bot_message"] + text = "" + + for source in message_sources: + if source in context: + message = context[source] + text = message.get("content", "") if isinstance(message, dict) else str(message) + if text: + log.debug(f"Checking {source} with Detections API detectors") + break + + if not text: + log.warning("No message content found in context for detection") + return DetectorResult( + allowed=True, + score=0.0, + reason="No message content found", + label="NO_CONTENT", + detector=detector_name, + metadata={}, + ).dict() + + if not hasattr(config, "rails") or not hasattr(config.rails, "config"): + log.warning("Configuration incomplete") + return DetectorResult( + allowed=True, + score=0.0, + reason="Configuration incomplete", + label="CONFIG_INCOMPLETE", + detector=detector_name, + metadata={}, + ).dict() + + detections_api_detectors = getattr(config.rails.config, "detections_api_detectors", {}) + + if detector_name not in detections_api_detectors: + return DetectorResult( + allowed=True, + score=0.0, + reason=f"Detector '{detector_name}' not configured", + label="NOT_CONFIGURED", + detector=detector_name, + metadata={}, + ).dict() + + detector_config = detections_api_detectors[detector_name] + + if detector_config is None: + return DetectorResult( + allowed=True, + score=0.0, + reason=f"Detector '{detector_name}' has no configuration", + label="NONE", + detector=detector_name, + metadata={}, + ).dict() + + result = await _run_detections_api_detector(detector_name, detector_config, text) + + log.info(f"Detections API {detector_name}: {'allowed' if result.allowed else 'blocked'} (score={result.score:.3f})") + + return result.dict() diff --git a/nemoguardrails/library/detector_clients/base.py b/nemoguardrails/library/detector_clients/base.py new file mode 100644 index 000000000..8e5ba04b5 --- /dev/null +++ b/nemoguardrails/library/detector_clients/base.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base interface for detector clients. +All detector implementations must inherit from this class. 
+""" + +import asyncio +import logging +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import aiohttp +from pydantic import BaseModel, Field + +log = logging.getLogger(__name__) + +# Global HTTP session for connection pooling +_http_session: Optional[aiohttp.ClientSession] = None +_session_lock = asyncio.Lock() + + +class DetectorResult(BaseModel): + """Standardized result from detector execution""" + + allowed: bool = Field(description="Whether content is allowed") + score: float = Field(description="Detection confidence score (0.0-1.0)") + reason: str = Field(description="Human-readable explanation") + label: str = Field(description="Detection label or category") + detector: str = Field(description="Detector name") + metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional detection metadata") + + +class AggregatedDetectorResult(BaseModel): + """Aggregated result from multiple detectors""" + + allowed: bool = Field(description="Whether content passed all detectors") + reason: str = Field(description="Summary of detection results") + blocking_detectors: List[DetectorResult] = Field(default_factory=list, description="Detectors that blocked content") + allowing_detectors: List[DetectorResult] = Field( + default_factory=list, description="Detectors that approved content" + ) + detector_count: int = Field(description="Total number of detectors run") + unavailable_detectors: Optional[List[str]] = Field( + default=None, description="Detectors that encountered system errors" + ) + + +class BaseDetectorClient(ABC): + """ + Abstract base class for all detector clients. + Defines the interface that all detector implementations must follow. + """ + + def __init__(self, config: Any, detector_name: str): + """ + Initialize detector client with configuration. + + Args: + config: Detector-specific configuration object + detector_name: Name of the detector for logging and error reporting + """ + self.config = config + self.detector_name = detector_name + self.endpoint = getattr(config, "inference_endpoint", "") + self.timeout = getattr(config, "timeout", 30) + self.api_key = getattr(config, "api_key", None) + + @abstractmethod + async def detect(self, text: str) -> DetectorResult: + """ + Main entry point for detection. + Orchestrates the detection flow: build request -> call endpoint -> parse response. + + Args: + text: Input text to analyze + + Returns: + DetectorResult with detection outcome + """ + pass + + @abstractmethod + def build_request(self, text: str) -> Dict[str, Any]: + """ + Build API-specific request payload. + + Args: + text: Input text to analyze + + Returns: + Request payload dict in API-specific format + """ + pass + + @abstractmethod + def parse_response(self, response: Any, http_status: int) -> DetectorResult: + """ + Parse API-specific response into standardized DetectorResult. + + Args: + response: API response data + http_status: HTTP status code from response + + Returns: + DetectorResult with parsed detection outcome + """ + pass + + async def _call_endpoint( + self, endpoint: str, payload: Dict[str, Any], timeout: int, headers: Optional[Dict[str, str]] = None + ) -> tuple[Any, int]: + """ + Make HTTP POST request to detector endpoint. + Shared implementation for all detector types. 
+ + Args: + endpoint: API endpoint URL + payload: Request payload + timeout: Request timeout in seconds + headers: Optional HTTP headers + + Returns: + Tuple of (response_data, http_status_code) + + Raises: + Exception: On HTTP errors or timeouts + """ + global _http_session + + # Lazy session initialization + if _http_session is None: + async with _session_lock: + if _http_session is None: + _http_session = aiohttp.ClientSession() + + # Build headers + request_headers = {"Content-Type": "application/json"} + if headers: + request_headers.update(headers) + + # Add auth if configured (per-detector key or global env var) + token = self.api_key or os.getenv("DETECTIONS_API_KEY") + if token: + request_headers["Authorization"] = f"Bearer {token}" + + timeout_config = aiohttp.ClientTimeout(total=timeout) + + try: + async with _http_session.post( + endpoint, json=payload, headers=request_headers, timeout=timeout_config + ) as response: + http_status = response.status + + if http_status == 200: + response_data = await response.json() + return response_data, http_status + else: + error_text = await response.text() + raise Exception(f"HTTP {http_status}: {error_text}") + + except asyncio.TimeoutError: + raise Exception(f"Request timeout after {timeout}s") + except aiohttp.ClientError as e: + raise Exception(f"HTTP client error: {str(e)}") + + def _handle_error(self, error: Exception, detector_name: str) -> DetectorResult: + """ + Convert exceptions into DetectorResult with error state. + Shared error handling for all detector types. + + Args: + error: Exception that occurred + detector_name: Name of detector for error reporting + + Returns: + DetectorResult indicating system error (blocked state) + """ + error_message = str(error) + + # Check if it's an HTTP error + if error_message.startswith("HTTP "): + label = "HTTP_ERROR" + reason = f"{detector_name} service error: {error_message}" + elif "timeout" in error_message.lower(): + label = "TIMEOUT" + reason = f"{detector_name} timeout: {error_message}" + else: + label = "ERROR" + reason = f"{detector_name} error: {error_message}" + + log.error(f"{detector_name} error: {error}") + + return DetectorResult( + allowed=False, + score=0.0, + reason=reason, + label=label, + detector=detector_name, + metadata={"error": error_message}, + ) + + +async def cleanup_http_session(): + """ + Close the shared HTTP session and release resources. + + The global aiohttp.ClientSession is shared across all detector clients for + connection pooling and performance. This function properly closes the session + to prevent resource leaks during application shutdown. + + This function is idempotent - it can be called multiple times safely. + + Args: + None + + Returns: + None + + Raises: + None + + Note: + Should be called once during application shutdown. The session will be + automatically recreated on next detector call if needed. + """ + global _http_session + + if _http_session is not None: + await _http_session.close() + _http_session = None + log.info("Detections API HTTP session closed") diff --git a/nemoguardrails/library/detector_clients/detections_api.py b/nemoguardrails/library/detector_clients/detections_api.py new file mode 100644 index 000000000..890188c52 --- /dev/null +++ b/nemoguardrails/library/detector_clients/detections_api.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Detections API v1/text/contents client implementation. +Handles communication with FMS-style detection endpoints. +""" + +import logging +from typing import Any, Dict, List + +from nemoguardrails.library.detector_clients.base import BaseDetectorClient, DetectorResult + +log = logging.getLogger(__name__) + + +class DetectionsAPIClient(BaseDetectorClient): + """ + Client for Detections API v1/text/contents endpoint. + + Expected API format: + - Request: POST with detector-id header, {"contents": [text], "detector_params": {}} + - Response: [[{detection1}, {detection2}, ...]] + """ + + def __init__(self, config: Any, detector_name: str): + """ + Initialize Detections API client. + + Args: + config: DetectionsAPIConfig with endpoint, detector_id, threshold, etc. + """ + super().__init__(config, detector_name) + self.detector_id = getattr(config, "detector_id", "") + self.threshold = getattr(config, "threshold", 0.5) + self.detector_params = getattr(config, "detector_params", {}) + + if not self.detector_id: + raise ValueError("detector_id is required for DetectionsAPIClient") + + async def detect(self, text: str) -> DetectorResult: + """ + Run detection on input text. + + Args: + text: Input text to analyze + + Returns: + DetectorResult with detection outcome + """ + try: + payload = self.build_request(text) + headers = {"detector-id": self.detector_id} + + response_data, http_status = await self._call_endpoint( + endpoint=self.endpoint, payload=payload, timeout=self.timeout, headers=headers + ) + + result = self.parse_response(response_data, http_status) + + log.info( + f"{self.detector_name}: {'allowed' if result.allowed else 'blocked'} " + f"(score={result.score:.3f}, " + f"detections={result.metadata.get('detection_count', 0) if result.metadata else 0})" + ) + + return result + + except Exception as e: + return self._handle_error(e, self.detector_name) + + def build_request(self, text: str) -> Dict[str, Any]: + """ + Build Detections API request payload. + + Args: + text: Input text to analyze + + Returns: + Request dict: {"contents": [text], "detector_params": {...}} + """ + return {"contents": [text], "detector_params": self.detector_params} + + def parse_response(self, response: Any, http_status: int) -> DetectorResult: + """ + Parse Detections API response into DetectorResult. 
+ + Response format: [[{detection1}, {detection2}, ...]] + Each detection: {start, end, text, detection_type, detection, score, evidence, metadata} + + Args: + response: API response data + http_status: HTTP status code + + Returns: + DetectorResult with parsed detection outcome + """ + if http_status != 200: + if http_status == 404: + label = "NOT_FOUND" + reason = f"{self.detector_name} detector not found" + elif http_status == 422: + label = "VALIDATION_ERROR" + reason = f"Invalid request to {self.detector_name}" + else: + label = "ERROR" + reason = f"HTTP {http_status} error from {self.detector_name}" + + return DetectorResult( + allowed=False, + score=0.0, + reason=reason, + label=label, + detector=self.detector_name, + metadata={"http_status": http_status}, + ) + + if not isinstance(response, list): + return DetectorResult( + allowed=False, + score=0.0, + reason="Invalid response format: expected list", + label="INVALID_RESPONSE", + detector=self.detector_name, + metadata={"response_type": type(response).__name__}, + ) + + detections = self._extract_detections_from_response(response) + + if not detections: + return DetectorResult( + allowed=True, + score=0.0, + reason="No detections found", + label="NONE", + detector=self.detector_name, + metadata={"detection_count": 0}, + ) + + filtered_detections = [d for d in detections if d.get("score", 0.0) >= self.threshold] + + if not filtered_detections: + return DetectorResult( + allowed=True, + score=max((d.get("score", 0.0) for d in detections), default=0.0), + reason=f"All detections below threshold {self.threshold}", + label="BELOW_THRESHOLD", + detector=self.detector_name, + metadata={ + "detection_count": 0, + "total_detections": len(detections), + "individual_scores": [d.get("score", 0.0) for d in detections], + "highest_detection": max(detections, key=lambda d: d.get("score", 0.0), default={}), + "detections": [{**d, "passed": d.get("score", 0.0) < self.threshold} for d in detections], + }, + ) + + highest_detection = self._get_highest_score_detection(filtered_detections) + highest_score = highest_detection.get("score", 0.0) + + detection_type = highest_detection.get("detection_type", "unknown") + detection_name = highest_detection.get("detection", "unknown") + label = f"{detection_type}:{detection_name}" + + reason = self._build_reason_message(filtered_detections) + + return DetectorResult( + allowed=False, + score=highest_score, + reason=reason, + label=label, + detector=self.detector_name, + metadata={ + "detection_count": len(filtered_detections), + "total_detections": len(detections), + "individual_scores": [d.get("score", 0.0) for d in detections], + "highest_detection": highest_detection, + "detections": [{**d, "passed": d.get("score", 0.0) < self.threshold} for d in detections], + }, + ) + + def _extract_detections_from_response(self, response: List[Any]) -> List[Dict[str, Any]]: + """ + Extract detections from nested array structure. + + Response format: [[{detection1}, {detection2}]] + + Args: + response: API response list + + Returns: + Flat list of detection dicts + """ + if not response: + return [] + + return response[0] + + def _get_highest_score_detection(self, detections: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Get detection with highest score. 
+ + Args: + detections: List of detection dicts + + Returns: + Detection dict with highest score + """ + if not detections: + return {} + + return max(detections, key=lambda d: d.get("score", 0.0)) + + def _build_reason_message(self, detections: List[Dict[str, Any]]) -> str: + """ + Build human-readable reason message from detections. + + Args: + detections: List of detection dicts + + Returns: + Formatted reason string + """ + count = len(detections) + + if count == 0: + return "No detections found" + + if count == 1: + det = detections[0] + detection_type = det.get("detection_type", "unknown") + detection_name = det.get("detection", "unknown") + score = det.get("score", 0.0) + return f"Blocked by {detection_type}:{detection_name} (score={score:.2f})" + + detection_types = set(d.get("detection_type", "unknown") for d in detections) + highest = self._get_highest_score_detection(detections) + highest_score = highest.get("score", 0.0) + + return ( + f"Blocked by {count} detections across {len(detection_types)} type(s) (highest score={highest_score:.2f})" + ) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 6e463f963..91f3f82d0 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -840,6 +840,45 @@ def get_validator_config(self, name: str) -> Optional[GuardrailsAIValidatorConfi return None +class KServeDetectorConfig(BaseModel): + """Configuration for single KServe detector.""" + + inference_endpoint: str = Field(description="The KServe API endpoint for the detector") + model_name: Optional[str] = Field(default=None, description="The name of the KServe model") + threshold: float = Field(default=0.5, description="Probability threshold for detection") + timeout: int = Field(default=30, description="HTTP request timeout in seconds") + api_key: Optional[str] = Field( + default=None, + description="Bearer token for authenticating to this detector. If not specified, uses KSERVE_API_KEY environment variable.", + ) + safe_labels: List[int] = Field(default_factory=lambda: [0], description="Class indices considered safe") + + +class DetectionsAPIConfig(BaseModel): + """Configuration for Detections API v1/text/contents detector.""" + + inference_endpoint: str = Field( + description="Detections API endpoint URL (e.g., http://service.com/v1/text/contents)" + ) + + detector_id: str = Field(description="Detector ID to send in detector-id header (e.g., dummy-en-pii-v1)") + + threshold: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Detection threshold (0.0-1.0). Block if any detection score >= threshold", + ) + + timeout: int = Field(default=30, gt=0, description="Request timeout in seconds") + + api_key: Optional[str] = Field(default=None, description="Optional API key for authentication (Bearer token)") + + detector_params: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Optional detector-specific parameters to send in request" + ) + + class TrendMicroRailConfig(BaseModel): """Configuration data for the Trend Micro AI Guard API""" @@ -945,6 +984,16 @@ class RailsConfigData(BaseModel): description="Configuration for Guardrails AI validators.", ) + kserve_detectors: Optional[Dict[str, KServeDetectorConfig]] = Field( + default_factory=dict, + description="Dynamic registry of KServe detectors. 
Keys are detector names, values are detector configurations.", + ) + + detections_api_detectors: Optional[Dict[str, DetectionsAPIConfig]] = Field( + default_factory=dict, + description="Dynamic registry of Detections API detectors. Keys are detector names, values are detector configurations.", + ) + trend_micro: Optional[TrendMicroRailConfig] = Field( default_factory=TrendMicroRailConfig, description="Configuration for Trend Micro.", diff --git a/tests/test_detector_clients_actions.py b/tests/test_detector_clients_actions.py new file mode 100644 index 000000000..8c55fb1e9 --- /dev/null +++ b/tests/test_detector_clients_actions.py @@ -0,0 +1,919 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for detector_clients/actions.py module. + +Tests cover: +- _run_detections_api_detector() helper function +- detections_api_check_all_detectors() action +- detections_api_check_detector() action +- Message type extraction (user_message, bot_message) +- Parallel execution and result aggregation +- Error categorization (system errors vs content blocks) +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from nemoguardrails.library.detector_clients.actions import ( + SYSTEM_ERROR_LABELS, + _run_detections_api_detector, + detections_api_check_all_detectors, + detections_api_check_detector, +) +from nemoguardrails.library.detector_clients.base import DetectorResult + + +class TestSystemErrorLabels: + """Tests for SYSTEM_ERROR_LABELS constant""" + + def test_system_error_labels_defined(self): + """Test SYSTEM_ERROR_LABELS constant is properly defined""" + assert isinstance(SYSTEM_ERROR_LABELS, set) + assert len(SYSTEM_ERROR_LABELS) > 0 + + def test_system_error_labels_contains_expected_values(self): + """Test set contains all expected system error labels""" + expected_labels = { + "ERROR", + "HTTP_ERROR", + "TIMEOUT", + "NOT_FOUND", + "VALIDATION_ERROR", + "INVALID_RESPONSE", + "CONFIG_ERROR", + } + + assert expected_labels.issubset(SYSTEM_ERROR_LABELS) + + +class TestRunDetectionsAPIDetector: + """Tests for _run_detections_api_detector() helper function""" + + @pytest.mark.asyncio + async def test_successful_detection(self): + """Test successful detector execution returns result""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + mock_config.threshold = 0.5 + + expected_result = DetectorResult(allowed=True, score=0.3, reason="Safe", label="SAFE", detector="test-detector") + + with patch("nemoguardrails.library.detector_clients.actions.DetectionsAPIClient") as MockClient: + mock_instance = MockClient.return_value + mock_instance.detect = AsyncMock(return_value=expected_result) + + result = await _run_detections_api_detector("test-detector", mock_config, "test text") + + assert result.allowed is True + assert result.score == 0.3 + assert result.label == "SAFE" + + 
@pytest.mark.asyncio + async def test_constructor_validation_error(self): + """Test ValueError from constructor is caught and returned as ERROR""" + mock_config = Mock() + mock_config.detector_id = "" # Will cause ValueError + + with patch("nemoguardrails.library.detector_clients.actions.DetectionsAPIClient") as MockClient: + MockClient.side_effect = ValueError("detector_id is required") + + result = await _run_detections_api_detector("test-detector", mock_config, "test") + + assert result.allowed is False + assert result.score == 0.0 + assert result.label == "ERROR" + assert "configuration error" in result.reason.lower() + assert result.detector == "test-detector" + + @pytest.mark.asyncio + async def test_detect_returns_error_result(self): + """Test when detect() returns error result (not exception)""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + + error_result = DetectorResult( + allowed=False, score=0.0, reason="Detector timeout", label="TIMEOUT", detector="test-detector" + ) + + with patch("nemoguardrails.library.detector_clients.actions.DetectionsAPIClient") as MockClient: + mock_instance = MockClient.return_value + mock_instance.detect = AsyncMock(return_value=error_result) + + result = await _run_detections_api_detector("test-detector", mock_config, "test") + + assert result.label == "TIMEOUT" + assert result.allowed is False + + +class TestDetectionsAPICheckAllDetectors: + """Tests for detections_api_check_all_detectors() action""" + + @pytest.mark.asyncio + async def test_no_context(self): + """Test with None context""" + result = await detections_api_check_all_detectors(context=None, config=None) + + assert result["allowed"] is False + assert "No configuration" in result["reason"] + + @pytest.mark.asyncio + async def test_no_config_in_params_or_context(self): + """Test when config not provided anywhere""" + context = {"user_message": "test"} + + result = await detections_api_check_all_detectors(context=context, config=None) + + assert result["allowed"] is False + assert "No configuration" in result["reason"] + assert result["detector_count"] == 0 + + @pytest.mark.asyncio + async def test_config_from_context(self): + """Test config extracted from context when not passed as parameter""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + mock_config.rails.config.detections_api_detectors = {} + + context = {"config": mock_config, "user_message": "test"} + + result = await detections_api_check_all_detectors(context=context, config=None) + + # Should find config in context and proceed + assert "No Detections API detectors configured" in result["reason"] + + @pytest.mark.asyncio + async def test_no_message_content(self): + """Test when no user_message or bot_message in context""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + context = {} # No message fields + + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + assert result["allowed"] is True + assert "No message content" in result["reason"] + assert result["detector_count"] == 0 + + @pytest.mark.asyncio + async def test_user_message_string(self): + """Test extraction of user_message as string""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + mock_config.rails.config.detections_api_detectors = {} + + context = {"user_message": "Hello world"} + + result = await 
detections_api_check_all_detectors(context=context, config=mock_config)
+
+        # Should extract message successfully
+        assert "No Detections API detectors configured" in result["reason"]
+
+    @pytest.mark.asyncio
+    async def test_user_message_dict(self):
+        """Test extraction of user_message as dict with content field"""
+        mock_config = Mock()
+        mock_config.rails = Mock()
+        mock_config.rails.config = Mock()
+        mock_config.rails.config.detections_api_detectors = {}
+
+        context = {"user_message": {"content": "Hello from dict", "role": "user"}}
+
+        result = await detections_api_check_all_detectors(context=context, config=mock_config)
+
+        assert "No Detections API detectors configured" in result["reason"]
+
+    @pytest.mark.asyncio
+    async def test_bot_message_extraction(self):
+        """Test bot_message extracted when user_message not present"""
+        mock_config = Mock()
+        mock_config.rails = Mock()
+        mock_config.rails.config = Mock()
+        mock_config.rails.config.detections_api_detectors = {}
+
+        context = {"bot_message": "Bot response here"}
+
+        result = await detections_api_check_all_detectors(context=context, config=mock_config)
+
+        # Should extract bot_message
+        assert "No Detections API detectors configured" in result["reason"]
+
+    @pytest.mark.asyncio
+    async def test_user_message_priority_over_bot(self):
+        """Test user_message takes priority when both present"""
+        mock_config = Mock()
+        mock_config.rails = Mock()
+        mock_config.rails.config = Mock()
+
+        detector_config = Mock()
+        detector_config.detector_id = "test-id"
+        detector_config.threshold = 0.5
+
+        mock_config.rails.config.detections_api_detectors = {"test": detector_config}
+
+        context = {"user_message": "User text", "bot_message": "Bot text"}
+
+        with patch("nemoguardrails.library.detector_clients.actions._run_detections_api_detector") as mock_run:
+            mock_run.return_value = DetectorResult(
+                allowed=True, score=0.0, reason="Test", label="SAFE", detector="test"
+            )
+
+            await detections_api_check_all_detectors(context=context, config=mock_config)
+
+            # Verify called with user_message, not bot_message
+            call_args = mock_run.call_args[0]
+            assert call_args[2] == "User text"  # text parameter
+
+    @pytest.mark.asyncio
+    async def test_config_incomplete_no_rails(self):
+        """Test when config exists but has no rails attribute"""
+        mock_config = Mock(spec=[])  # Config without rails attribute
+
+        context = {"user_message": "test"}
+
+        result = await detections_api_check_all_detectors(context=context, config=mock_config)
+
+        assert result["allowed"] is True
+        # The action reports "Configuration incomplete" in this case
+        assert "Configuration incomplete" in result["reason"]
+
+    @pytest.mark.asyncio
+    async def test_config_incomplete_no_config_attr(self):
+        """Test when config.rails exists but has no config attribute"""
+        mock_config = Mock()
+        mock_config.rails = Mock(spec=[])  # rails without config attribute
+
+        context = {"user_message": "test"}
+
+        result = await detections_api_check_all_detectors(context=context, config=mock_config)
+
+        assert result["allowed"] is True
+        # The action reports "Configuration incomplete" in this case
+        assert "Configuration incomplete" in result["reason"]
+
+    @pytest.mark.asyncio
+    async def test_no_detectors_configured(self):
+        """Test when detections_api_detectors is empty dict"""
+        mock_config = Mock()
+        mock_config.rails = Mock()
+        mock_config.rails.config = Mock()
+        mock_config.rails.config.detections_api_detectors = {}
+
+        context = {"user_message": "test"}
+
+        result = await detections_api_check_all_detectors(context=context, config=mock_config)
+
+        
assert result["allowed"] is True + assert "No Detections API detectors configured" in result["reason"] + assert result["detector_count"] == 0 + + @pytest.mark.asyncio + async def test_single_detector_passes(self): + """Test single detector that allows content""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + detector_config = Mock() + mock_config.rails.config.detections_api_detectors = {"toxicity": detector_config} + + context = {"user_message": "Hello world"} + + passing_result = DetectorResult( + allowed=True, score=0.1, reason="Safe content", label="SAFE", detector="toxicity" + ) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", return_value=passing_result + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + assert result["allowed"] is True + assert "Approved by all 1" in result["reason"] + assert len(result["allowing_detectors"]) == 1 + assert len(result["blocking_detectors"]) == 0 + assert result["detector_count"] == 1 + + @pytest.mark.asyncio + async def test_single_detector_blocks(self): + """Test single detector that blocks content""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + detector_config = Mock() + mock_config.rails.config.detections_api_detectors = {"toxicity": detector_config} + + context = {"user_message": "bad content"} + + blocking_result = DetectorResult( + allowed=False, score=0.95, reason="Toxic content detected", label="toxicity:profanity", detector="toxicity" + ) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", return_value=blocking_result + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + assert result["allowed"] is False + assert "Blocked by 1" in result["reason"] + assert len(result["blocking_detectors"]) == 1 + assert len(result["allowing_detectors"]) == 0 + assert result["detector_count"] == 1 + + @pytest.mark.asyncio + async def test_multiple_detectors_all_pass(self): + """Test multiple detectors all allowing content""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"toxicity": Mock(), "jailbreak": Mock(), "pii": Mock()} + + context = {"user_message": "safe message"} + + passing_result = DetectorResult(allowed=True, score=0.1, reason="Safe", label="SAFE", detector="test") + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", return_value=passing_result + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + assert result["allowed"] is True + assert "Approved by all 3" in result["reason"] + assert len(result["allowing_detectors"]) == 3 + assert result["detector_count"] == 3 + + @pytest.mark.asyncio + async def test_multiple_detectors_some_block(self): + """Test multiple detectors with mixed results""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"toxicity": Mock(), "jailbreak": Mock(), "pii": Mock()} + + context = {"user_message": "test"} + + async def mock_detector(name, config, text): + if name == "toxicity": + return DetectorResult( + allowed=False, score=0.9, reason="Toxic", label="toxicity:profanity", detector=name + ) + elif name == "jailbreak": + return DetectorResult( + allowed=False, score=0.8, 
reason="Jailbreak", label="jailbreak:injection", detector=name + ) + else: + return DetectorResult(allowed=True, score=0.1, reason="Safe", label="SAFE", detector=name) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", side_effect=mock_detector + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + assert result["allowed"] is False + assert "Blocked by 2" in result["reason"] + assert len(result["blocking_detectors"]) == 2 + assert len(result["allowing_detectors"]) == 1 + assert result["detector_count"] == 3 + + @pytest.mark.asyncio + async def test_system_error_handling(self): + """Test detector with system error (TIMEOUT) goes to unavailable list""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"detector1": Mock(), "detector2": Mock()} + + context = {"user_message": "test"} + + async def mock_detector(name, config, text): + if name == "detector1": + return DetectorResult(allowed=False, score=0.0, reason="Timeout", label="TIMEOUT", detector=name) + else: + return DetectorResult(allowed=True, score=0.1, reason="Safe", label="SAFE", detector=name) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", side_effect=mock_detector + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + assert result["allowed"] is False + assert "System error" in result["reason"] + assert result["unavailable_detectors"] == ["detector1"] + assert len(result["blocking_detectors"]) == 0 # TIMEOUT not a content block + assert len(result["allowing_detectors"]) == 1 + + @pytest.mark.asyncio + async def test_http_error_classified_as_system_error(self): + """Test HTTP_ERROR label goes to system errors, not content blocks""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"test": Mock()} + + context = {"user_message": "test"} + + http_error_result = DetectorResult( + allowed=False, score=0.0, reason="HTTP error", label="HTTP_ERROR", detector="test" + ) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", + return_value=http_error_result, + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + # HTTP_ERROR should go to system errors, not content blocks + assert result["unavailable_detectors"] == ["test"] + assert len(result["blocking_detectors"]) == 0 + + @pytest.mark.asyncio + async def test_exception_from_gather(self): + """Test Exception raised during asyncio.gather is handled""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"test": Mock()} + + context = {"user_message": "test"} + + # CORRECTED: Make the mock function raise exception + # asyncio.gather with return_exceptions=True will catch it and return it as a result + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", + side_effect=RuntimeError("Unexpected error"), + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + # Exception should be converted to DetectorResult with ERROR label + assert result["allowed"] is False + assert "System error" in result["reason"] + + @pytest.mark.asyncio + async def test_parallel_execution(self): + 
"""Test detectors run in parallel via asyncio.gather""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = { + "detector1": Mock(), + "detector2": Mock(), + "detector3": Mock(), + } + + context = {"user_message": "test"} + + call_count = 0 + + async def mock_detector(name, config, text): + nonlocal call_count + call_count += 1 + return DetectorResult(allowed=True, score=0.0, reason="Test", label="SAFE", detector=name) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", side_effect=mock_detector + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + # All 3 detectors should have been called + assert call_count == 3 + assert result["detector_count"] == 3 + + @pytest.mark.asyncio + async def test_detector_names_not_shadowed(self): + """Test detector_names variable not incorrectly shadowed""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"detector1": Mock(), "detector2": Mock()} + + context = {"user_message": "test"} + + async def mock_detector(name, config, text): + if name == "detector1": + return DetectorResult(allowed=False, score=0.9, reason="Block", label="toxic:bad", detector=name) + else: + return DetectorResult(allowed=True, score=0.1, reason="Safe", label="SAFE", detector=name) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", side_effect=mock_detector + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + # Verify blocking detector name appears correctly + assert "detector1" in result["reason"] + assert len(result["blocking_detectors"]) == 1 + # CORRECTED: blocking_detectors is list of DetectorResult.dict(), access as dict + assert result["blocking_detectors"][0]["detector"] == "detector1" + + +class TestDetectionsAPICheckDetector: + """Tests for detections_api_check_detector() action""" + + @pytest.mark.asyncio + async def test_no_config(self): + """Test with no configuration""" + result = await detections_api_check_detector( + context={"user_message": "test"}, config=None, detector_name="toxicity" + ) + + assert result["allowed"] is False + assert "No configuration" in result["reason"] + assert result["label"] == "NO_CONFIG" + + @pytest.mark.asyncio + async def test_no_message_content(self): + """Test when no message in context""" + mock_config = Mock() + + result = await detections_api_check_detector(context={}, config=mock_config, detector_name="toxicity") + + assert result["allowed"] is True + assert "No message content" in result["reason"] + assert result["label"] == "NO_CONTENT" + + @pytest.mark.asyncio + async def test_config_incomplete(self): + """Test when config structure is incomplete""" + mock_config = Mock(spec=[]) # No rails attribute + + result = await detections_api_check_detector( + context={"user_message": "test"}, config=mock_config, detector_name="toxicity" + ) + + assert result["allowed"] is True + # CORRECTED: Match actual string from code + assert "Configuration incomplete" in result["reason"] + assert result["label"] == "CONFIG_INCOMPLETE" + + @pytest.mark.asyncio + async def test_detector_not_configured(self): + """Test when requested detector not in config""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + 
mock_config.rails.config.detections_api_detectors = { + "jailbreak": Mock() # Different detector + } + + context = {"user_message": "test"} + + result = await detections_api_check_detector( + context=context, + config=mock_config, + detector_name="toxicity", # Not in config + ) + + assert result["allowed"] is True + assert result["label"] == "NOT_CONFIGURED" + assert "not configured" in result["reason"].lower() + + @pytest.mark.asyncio + async def test_detector_config_is_none(self): + """Test when detector config value is None""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + mock_config.rails.config.detections_api_detectors = { + "toxicity": None # Config is None + } + + context = {"user_message": "test"} + + result = await detections_api_check_detector(context=context, config=mock_config, detector_name="toxicity") + + assert result["allowed"] is True + assert result["label"] == "NONE" + assert "no configuration" in result["reason"].lower() + + @pytest.mark.asyncio + async def test_detector_successful_detection(self): + """Test successful detector execution""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + detector_config = Mock() + mock_config.rails.config.detections_api_detectors = {"toxicity": detector_config} + + context = {"user_message": "test message"} + + detection_result = DetectorResult( + allowed=False, score=0.88, reason="Blocked", label="toxicity:hate", detector="toxicity" + ) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", + return_value=detection_result, + ): + result = await detections_api_check_detector(context=context, config=mock_config, detector_name="toxicity") + + assert result["allowed"] is False + assert result["score"] == 0.88 + assert result["label"] == "toxicity:hate" + + @pytest.mark.asyncio + async def test_detector_with_bot_message(self): + """Test detector works with bot_message (output guardrail)""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + detector_config = Mock() + mock_config.rails.config.detections_api_detectors = {"toxicity": detector_config} + + context = {"bot_message": "bot response text"} + + with patch("nemoguardrails.library.detector_clients.actions._run_detections_api_detector") as mock_run: + mock_run.return_value = DetectorResult( + allowed=True, score=0.0, reason="Safe", label="SAFE", detector="toxicity" + ) + + result = await detections_api_check_detector(context=context, config=mock_config, detector_name="toxicity") + + # Verify called with bot_message text + call_args = mock_run.call_args[0] + assert call_args[2] == "bot response text" + + assert result["allowed"] is True + + +class TestResultAggregation: + """Tests for result aggregation logic in check_all_detectors""" + + @pytest.mark.asyncio + async def test_system_error_with_content_blocks(self): + """Test system error and content block both present""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = { + "detector1": Mock(), + "detector2": Mock(), + "detector3": Mock(), + } + + context = {"user_message": "test"} + + async def mock_detector(name, config, text): + if name == "detector1": + # System error + return DetectorResult(allowed=False, score=0.0, reason="Error", label="ERROR", detector=name) + elif name == "detector2": + # Content block + return DetectorResult(allowed=False, score=0.9, reason="Toxic", 
label="toxic:bad", detector=name) + else: + # Allowing + return DetectorResult(allowed=True, score=0.1, reason="Safe", label="SAFE", detector=name) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", side_effect=mock_detector + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + # System error takes precedence in response + assert result["allowed"] is False + assert "System error" in result["reason"] + assert result["unavailable_detectors"] == ["detector1"] + assert len(result["blocking_detectors"]) == 1 # detector2 content block + assert len(result["allowing_detectors"]) == 1 # detector3 + + @pytest.mark.asyncio + async def test_all_system_errors(self): + """Test when all detectors have system errors""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"detector1": Mock(), "detector2": Mock()} + + context = {"user_message": "test"} + + async def mock_detector(name, config, text): + return DetectorResult(allowed=False, score=0.0, reason="Error", label="TIMEOUT", detector=name) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", side_effect=mock_detector + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + assert result["allowed"] is False + assert "2 Detections API detector(s) unavailable" in result["reason"] + assert len(result["unavailable_detectors"]) == 2 + assert len(result["blocking_detectors"]) == 0 + assert len(result["allowing_detectors"]) == 0 + + @pytest.mark.asyncio + async def test_blocking_detector_names_distinct_from_all_detector_names(self): + """Test that blocking_detector_names are subset of all detector_names""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + mock_config.rails.config.detections_api_detectors = {"toxicity": Mock(), "jailbreak": Mock(), "pii": Mock()} + + context = {"user_message": "test"} + + async def mock_detector(name, config, text): + if name == "toxicity": + return DetectorResult(allowed=False, score=0.9, reason="Block", label="toxic:bad", detector=name) + else: + return DetectorResult(allowed=True, score=0.1, reason="Safe", label="SAFE", detector=name) + + with patch( + "nemoguardrails.library.detector_clients.actions._run_detections_api_detector", side_effect=mock_detector + ): + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + # Only toxicity should be in reason + assert "toxicity" in result["reason"] + assert "jailbreak" not in result["reason"] + assert "pii" not in result["reason"] + + +class TestMessageTypeExtraction: + """Tests for message type extraction (input/output guardrails)""" + + @pytest.mark.asyncio + async def test_extracts_user_message_string(self): + """Test user_message extracted when it's a string""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + mock_config.rails.config.detections_api_detectors = {"test": Mock()} + + context = {"user_message": "user input text"} + + with patch("nemoguardrails.library.detector_clients.actions._run_detections_api_detector") as mock_run: + mock_run.return_value = DetectorResult( + allowed=True, score=0.0, reason="Test", label="SAFE", detector="test" + ) + + await detections_api_check_all_detectors(context=context, config=mock_config) + + # Verify text parameter + assert mock_run.call_args[0][2] == 
"user input text" + + @pytest.mark.asyncio + async def test_extracts_user_message_dict_with_content(self): + """Test user_message extracted from dict with content field""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + mock_config.rails.config.detections_api_detectors = {"test": Mock()} + + context = {"user_message": {"content": "message content", "role": "user", "other": "fields"}} + + with patch("nemoguardrails.library.detector_clients.actions._run_detections_api_detector") as mock_run: + mock_run.return_value = DetectorResult( + allowed=True, score=0.0, reason="Test", label="SAFE", detector="test" + ) + + await detections_api_check_all_detectors(context=context, config=mock_config) + + assert mock_run.call_args[0][2] == "message content" + + @pytest.mark.asyncio + async def test_extracts_bot_message_when_no_user_message(self): + """Test bot_message used when user_message not present""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + mock_config.rails.config.detections_api_detectors = {"test": Mock()} + + context = {"bot_message": "bot response"} + + with patch("nemoguardrails.library.detector_clients.actions._run_detections_api_detector") as mock_run: + mock_run.return_value = DetectorResult( + allowed=True, score=0.0, reason="Test", label="SAFE", detector="test" + ) + + await detections_api_check_all_detectors(context=context, config=mock_config) + + assert mock_run.call_args[0][2] == "bot response" + + @pytest.mark.asyncio + async def test_empty_message_dict_without_content(self): + """Test dict message without content field""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + context = {"user_message": {"role": "user"}} # No content field + + result = await detections_api_check_all_detectors(context=context, config=mock_config) + + # Empty content string + assert result["allowed"] is True + assert "No message content" in result["reason"] + + +class TestReturnFormatConsistency: + """Tests for return format consistency""" + + @pytest.mark.asyncio + async def test_all_returns_have_required_fields(self): + """Test all error returns have consistent AggregatedDetectorResult structure""" + mock_config = Mock() + + # Test no config + result1 = await detections_api_check_all_detectors(context={}, config=None) + assert "allowed" in result1 + assert "reason" in result1 + assert "blocking_detectors" in result1 + assert "allowing_detectors" in result1 + assert "detector_count" in result1 + + # Test no message + mock_config.rails = Mock() + mock_config.rails.config = Mock() + result2 = await detections_api_check_all_detectors(context={}, config=mock_config) + assert "allowed" in result2 + assert "detector_count" in result2 + + # Test incomplete config + mock_config2 = Mock(spec=[]) + result3 = await detections_api_check_all_detectors(context={"user_message": "test"}, config=mock_config2) + assert "allowed" in result3 + assert "detector_count" in result3 + + @pytest.mark.asyncio + async def test_check_detector_returns_detector_result_format(self): + """Test check_detector always returns DetectorResult.dict() format""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + detector_config = Mock() + mock_config.rails.config.detections_api_detectors = {"test": detector_config} + + context = {"user_message": "test"} + + expected_result = DetectorResult(allowed=True, score=0.2, reason="Safe", label="SAFE", detector="test") + + with patch( + 
"nemoguardrails.library.detector_clients.actions._run_detections_api_detector", return_value=expected_result + ): + result = await detections_api_check_detector(context=context, config=mock_config, detector_name="test") + + # Should have DetectorResult fields + assert "allowed" in result + assert "score" in result + assert "reason" in result + assert "label" in result + assert "detector" in result + + +class TestDefaultParameters: + """Tests for default parameter values""" + + @pytest.mark.asyncio + async def test_check_detector_default_detector_name(self): + """Test check_detector uses default detector_name""" + mock_config = Mock() + mock_config.rails = Mock() + mock_config.rails.config = Mock() + + # Default is "mock_pii" according to function signature + mock_config.rails.config.detections_api_detectors = {"mock_pii": Mock()} + + context = {"user_message": "test"} + + with patch("nemoguardrails.library.detector_clients.actions._run_detections_api_detector") as mock_run: + mock_run.return_value = DetectorResult( + allowed=True, score=0.0, reason="Test", label="SAFE", detector="mock_pii" + ) + + # Call without detector_name parameter + result = await detections_api_check_detector(context=context, config=mock_config) + + # Should use default "mock_pii" + assert mock_run.call_args[0][0] == "mock_pii" diff --git a/tests/test_detector_clients_base.py b/tests/test_detector_clients_base.py new file mode 100644 index 000000000..370ce426f --- /dev/null +++ b/tests/test_detector_clients_base.py @@ -0,0 +1,484 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for detector_clients/base.py module. 
+ +Tests cover: +- DetectorResult model validation +- AggregatedDetectorResult model validation +- BaseDetectorClient error handling +- HTTP session cleanup +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from pydantic import ValidationError + +from nemoguardrails.library.detector_clients.base import ( + AggregatedDetectorResult, + BaseDetectorClient, + DetectorResult, + cleanup_http_session, +) + + +class TestDetectorResult: + """Tests for DetectorResult model""" + + def test_valid_detector_result(self): + """Test creating valid DetectorResult""" + result = DetectorResult( + allowed=True, + score=0.75, + reason="Test passed", + label="SAFE", + detector="test-detector", + metadata={"key": "value"}, + ) + + assert result.allowed is True + assert result.score == 0.75 + assert result.reason == "Test passed" + assert result.label == "SAFE" + assert result.detector == "test-detector" + assert result.metadata == {"key": "value"} + + def test_detector_result_without_metadata(self): + """Test DetectorResult with optional metadata as None""" + result = DetectorResult( + allowed=False, score=0.95, reason="Blocked", label="TOXIC", detector="toxicity-detector" + ) + + assert result.metadata is None + + def test_detector_result_missing_required_fields(self): + """Test that missing required fields raises ValidationError""" + with pytest.raises(ValidationError): + DetectorResult( + allowed=True, + score=0.5, + # Missing: reason, label, detector + ) + + def test_detector_result_type_coercion(self): + """Test Pydantic type coercion""" + result = DetectorResult( + allowed="yes", # String coerced to bool + score="0.8", # String coerced to float + reason="Test", + label="SAFE", + detector="test", + ) + + assert result.allowed is True + assert result.score == 0.8 + + def test_detector_result_to_dict(self): + """Test .dict() serialization""" + result = DetectorResult( + allowed=False, score=0.9, reason="Test", label="BLOCK", detector="test", metadata={"foo": "bar"} + ) + + result_dict = result.dict() + + assert isinstance(result_dict, dict) + assert result_dict["allowed"] is False + assert result_dict["score"] == 0.9 + assert result_dict["metadata"] == {"foo": "bar"} + + +class TestAggregatedDetectorResult: + """Tests for AggregatedDetectorResult model""" + + def test_valid_aggregated_result(self): + """Test creating valid AggregatedDetectorResult""" + blocking = DetectorResult(allowed=False, score=0.9, reason="Toxic", label="TOXIC", detector="toxicity") + + allowing = DetectorResult(allowed=True, score=0.1, reason="Safe", label="SAFE", detector="pii") + + result = AggregatedDetectorResult( + allowed=False, + reason="Blocked by 1 detector", + blocking_detectors=[blocking], + allowing_detectors=[allowing], + detector_count=2, + unavailable_detectors=None, + ) + + assert result.allowed is False + assert len(result.blocking_detectors) == 1 + assert len(result.allowing_detectors) == 1 + assert result.detector_count == 2 + + def test_aggregated_result_with_defaults(self): + """Test AggregatedDetectorResult with default list values""" + result = AggregatedDetectorResult(allowed=True, reason="All passed", detector_count=0) + + assert result.blocking_detectors == [] + assert result.allowing_detectors == [] + assert result.unavailable_detectors is None + + def test_aggregated_result_with_unavailable(self): + """Test tracking unavailable detectors""" + result = AggregatedDetectorResult( + allowed=False, + reason="System error", + blocking_detectors=[], + allowing_detectors=[], + 
detector_count=2,
+            unavailable_detectors=["detector1", "detector2"],
+        )
+
+        assert result.unavailable_detectors == ["detector1", "detector2"]
+
+
+class ConcreteDetectorClient(BaseDetectorClient):
+    """Concrete implementation of BaseDetectorClient for testing"""
+
+    async def detect(self, text: str) -> DetectorResult:
+        return DetectorResult(allowed=True, score=0.0, reason="Test", label="TEST", detector=self.detector_name)
+
+    def build_request(self, text: str):
+        return {"text": text}
+
+    def parse_response(self, response, http_status):
+        return DetectorResult(allowed=True, score=0.0, reason="Test", label="TEST", detector=self.detector_name)
+
+
+class TestBaseDetectorClient:
+    """Tests for BaseDetectorClient base class"""
+
+    def test_init_with_full_config(self):
+        """Test initialization with complete configuration"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com/api"
+        mock_config.timeout = 60
+        mock_config.api_key = "test-key-123"
+
+        client = ConcreteDetectorClient(mock_config, "test-detector")
+
+        assert client.detector_name == "test-detector"
+        assert client.endpoint == "http://test.com/api"
+        assert client.timeout == 60
+        assert client.api_key == "test-key-123"
+
+    def test_init_with_minimal_config(self):
+        """Test initialization with minimal config (using defaults)"""
+        from types import SimpleNamespace
+
+        # Use SimpleNamespace instead of Mock - only has attributes we set
+        mock_config = SimpleNamespace()
+        mock_config.inference_endpoint = "http://test.com"
+        # Don't set timeout or api_key - they won't exist
+
+        client = ConcreteDetectorClient(mock_config, "test-detector")
+
+        assert client.endpoint == "http://test.com"
+        assert client.timeout == 30  # Default
+        assert client.api_key is None  # Default
+
+
+class TestHandleError:
+    """Tests for BaseDetectorClient._handle_error() method"""
+
+    def test_handle_timeout_error(self):
+        """Test timeout error creates TIMEOUT label and appropriate reason"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+
+        client = ConcreteDetectorClient(mock_config, "test-detector")
+        error = Exception("Request timeout after 30s")
+
+        result = client._handle_error(error, "test-detector")
+
+        assert result.allowed is False
+        assert result.score == 0.0
+        assert result.label == "TIMEOUT"
+        assert "timeout" in result.reason.lower()
+        assert result.detector == "test-detector"
+        assert result.metadata["error"] == "Request timeout after 30s"
+
+    def test_handle_http_error(self):
+        """Test HTTP error creates HTTP_ERROR label"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+
+        client = ConcreteDetectorClient(mock_config, "test-detector")
+        error = Exception("HTTP 500: Internal Server Error")
+
+        result = client._handle_error(error, "test-detector")
+
+        assert result.allowed is False
+        assert result.score == 0.0
+        assert result.label == "HTTP_ERROR"
+        # _handle_error phrases this as a "service error"
+        assert "service error" in result.reason
+        assert "HTTP 500" in result.reason
+        assert result.metadata["error"] == "HTTP 500: Internal Server Error"
+
+    def test_handle_http_404_error(self):
+        """Test HTTP 404 error"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+
+        client = ConcreteDetectorClient(mock_config, "test-detector")
+        error = Exception("HTTP 404: Not Found")
+
+        result = client._handle_error(error, "test-detector")
+
+        assert result.allowed is False
+        assert result.label == "HTTP_ERROR"
+        assert "HTTP 404" in 
result.reason + + def test_handle_generic_error(self): + """Test generic error creates ERROR label""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + + client = ConcreteDetectorClient(mock_config, "test-detector") + error = Exception("Something went wrong") + + result = client._handle_error(error, "test-detector") + + assert result.allowed is False + assert result.score == 0.0 + assert result.label == "ERROR" + assert result.reason == "test-detector error: Something went wrong" + assert result.metadata["error"] == "Something went wrong" + + def test_handle_error_with_special_characters(self): + """Test error messages with special characters are handled""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + + client = ConcreteDetectorClient(mock_config, "test-detector") + error = Exception("Error: 'quoted' and \"double-quoted\" text") + + result = client._handle_error(error, "test-detector") + + assert result.allowed is False + assert "quoted" in result.reason + + +class TestCleanupHttpSession: + """Tests for cleanup_http_session() function""" + + @pytest.mark.asyncio + async def test_cleanup_when_session_exists(self): + """Test cleanup closes existing session""" + from nemoguardrails.library.detector_clients import base + + # Create a mock session + mock_session = AsyncMock() + base._http_session = mock_session + + await cleanup_http_session() + + # Verify session was closed + mock_session.close.assert_called_once() + assert base._http_session is None + + @pytest.mark.asyncio + async def test_cleanup_when_no_session(self): + """Test cleanup is safe when no session exists""" + from nemoguardrails.library.detector_clients import base + + base._http_session = None + + # Should not raise + await cleanup_http_session() + + assert base._http_session is None + + @pytest.mark.asyncio + async def test_cleanup_idempotent(self): + """Test cleanup can be called multiple times safely""" + from nemoguardrails.library.detector_clients import base + + mock_session = AsyncMock() + base._http_session = mock_session + + # Call cleanup twice + await cleanup_http_session() + await cleanup_http_session() + + # Should only close once (session is None on second call) + mock_session.close.assert_called_once() + assert base._http_session is None + + +class TestCallEndpoint: + """Tests for BaseDetectorClient._call_endpoint() method""" + + @pytest.mark.asyncio + async def test_successful_post_request(self): + """Test successful HTTP POST returns data and status""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + + client = ConcreteDetectorClient(mock_config, "test-detector") + + # Mock the HTTP response properly + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "success"}) + + # Mock session.post to return async context manager + mock_post_cm = AsyncMock() + mock_post_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_post_cm.__aexit__ = AsyncMock(return_value=None) + + mock_session = Mock() + mock_session.post = Mock(return_value=mock_post_cm) + + with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session): + data, status = await client._call_endpoint( + endpoint="http://test.com/api", payload={"text": "test"}, timeout=30 + ) + + assert status == 200 + assert data == {"result": "success"} + + @pytest.mark.asyncio + async def test_post_request_with_headers(self): + """Test request includes custom headers""" + mock_config = Mock() + 
mock_config.inference_endpoint = "http://test.com" + mock_config.api_key = "secret-key" + + client = ConcreteDetectorClient(mock_config, "test-detector") + + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={}) + + mock_post_cm = AsyncMock() + mock_post_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_post_cm.__aexit__ = AsyncMock(return_value=None) + + mock_session = Mock() + mock_session.post = Mock(return_value=mock_post_cm) + + with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session): + await client._call_endpoint( + endpoint="http://test.com/api", payload={"text": "test"}, timeout=30, headers={"Custom-Header": "value"} + ) + + # Verify headers were passed + call_kwargs = mock_session.post.call_args[1] + assert "Custom-Header" in call_kwargs["headers"] + assert call_kwargs["headers"]["Custom-Header"] == "value" + + @pytest.mark.asyncio + async def test_post_request_timeout(self): + """Test timeout raises appropriate exception""" + import asyncio + + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + + client = ConcreteDetectorClient(mock_config, "test-detector") + + # Mock session.post to raise timeout + mock_session = Mock() + mock_session.post = Mock(side_effect=asyncio.TimeoutError()) + + with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session): + with pytest.raises(Exception, match="timeout"): + await client._call_endpoint(endpoint="http://test.com/api", payload={"text": "test"}, timeout=30) + + @pytest.mark.asyncio + async def test_post_request_non_200_status(self): + """Test non-200 status raises exception""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + + client = ConcreteDetectorClient(mock_config, "test-detector") + + mock_response = Mock() + mock_response.status = 500 + mock_response.text = AsyncMock(return_value="Internal Server Error") + + mock_post_cm = AsyncMock() + mock_post_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_post_cm.__aexit__ = AsyncMock(return_value=None) + + mock_session = Mock() + mock_session.post = Mock(return_value=mock_post_cm) + + with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session): + with pytest.raises(Exception, match="500"): + await client._call_endpoint(endpoint="http://test.com/api", payload={"text": "test"}, timeout=30) + + @pytest.mark.asyncio + async def test_post_request_with_api_key(self): + """Test Authorization header added when api_key configured""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.api_key = "test-api-key" + + client = ConcreteDetectorClient(mock_config, "test-detector") + + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={}) + + mock_post_cm = AsyncMock() + mock_post_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_post_cm.__aexit__ = AsyncMock(return_value=None) + + mock_session = Mock() + mock_session.post = Mock(return_value=mock_post_cm) + + with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session): + await client._call_endpoint(endpoint="http://test.com/api", payload={"text": "test"}, timeout=30) + + # Verify Authorization header + call_kwargs = mock_session.post.call_args[1] + assert "Authorization" in call_kwargs["headers"] + assert call_kwargs["headers"]["Authorization"] == "Bearer test-api-key" + + @pytest.mark.asyncio + async def 
test_post_request_with_env_api_key(self):
+        """Test fallback to environment variable for API key"""
+        from types import SimpleNamespace
+
+        # Use SimpleNamespace so the config has no api_key attribute at all
+        mock_config = SimpleNamespace()
+        mock_config.inference_endpoint = "http://test.com"
+        # Don't set api_key - it won't exist, forcing env var lookup
+
+        client = ConcreteDetectorClient(mock_config, "test-detector")
+
+        mock_response = Mock()
+        mock_response.status = 200
+        mock_response.json = AsyncMock(return_value={})
+
+        mock_post_cm = AsyncMock()
+        mock_post_cm.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_post_cm.__aexit__ = AsyncMock(return_value=None)
+
+        mock_session = Mock()
+        mock_session.post = Mock(return_value=mock_post_cm)
+
+        with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session):
+            with patch.dict("os.environ", {"DETECTIONS_API_KEY": "env-key-456"}):
+                await client._call_endpoint(endpoint="http://test.com/api", payload={"text": "test"}, timeout=30)
+
+        # Verify env var key used
+        call_kwargs = mock_session.post.call_args[1]
+        assert call_kwargs["headers"]["Authorization"] == "Bearer env-key-456"
diff --git a/tests/test_detector_clients_detections_api.py b/tests/test_detector_clients_detections_api.py
new file mode 100644
index 000000000..ed3d452c7
--- /dev/null
+++ b/tests/test_detector_clients_detections_api.py
@@ -0,0 +1,810 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unit tests for detector_clients/detections_api.py module.
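+
+All detector configurations and HTTP calls are mocked, so no running
+detector service is required.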
+
+Tests cover:
+- DetectionsAPIClient initialization
+- Request payload building
+- Response parsing for all scenarios
+- Error handling (HTTP errors, invalid responses)
+- Helper methods
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from nemoguardrails.library.detector_clients.detections_api import DetectionsAPIClient
+
+
+class TestDetectionsAPIClientInit:
+    """Tests for DetectionsAPIClient initialization"""
+
+    def test_init_with_valid_config(self):
+        """Test initialization with complete configuration"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://detector.com/api"
+        mock_config.detector_id = "test-detector-v1"
+        mock_config.threshold = 0.8
+        mock_config.timeout = 60
+        mock_config.detector_params = {"param1": "value1"}
+        mock_config.api_key = "test-key"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        assert client.detector_name == "test-detector"
+        assert client.endpoint == "http://detector.com/api"
+        assert client.detector_id == "test-detector-v1"
+        assert client.threshold == 0.8
+        assert client.timeout == 60
+        assert client.detector_params == {"param1": "value1"}
+        assert client.api_key == "test-key"
+
+    def test_init_with_defaults(self):
+        """Test initialization uses default values when not specified"""
+        from types import SimpleNamespace
+
+        # Use SimpleNamespace so the optional attributes are genuinely absent
+        mock_config = SimpleNamespace()
+        mock_config.inference_endpoint = "http://detector.com"
+        mock_config.detector_id = "test-id"
+        # Don't set threshold, detector_params, timeout - getattr will use defaults
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        assert client.threshold == 0.5  # Default
+        assert client.detector_params == {}  # Default
+        assert client.timeout == 30  # Default from BaseDetectorClient
+
+    def test_init_missing_detector_id_raises_error(self):
+        """Test initialization fails when detector_id is empty string"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://detector.com"
+        mock_config.detector_id = ""  # Empty string
+
+        with pytest.raises(ValueError, match="detector_id is required"):
+            DetectionsAPIClient(mock_config, "test-detector")
+
+
+class TestBuildRequest:
+    """Tests for DetectionsAPIClient.build_request()"""
+
+    def test_build_request_basic(self):
+        """Test request payload format"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.detector_params = {}
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        request = client.build_request("test text content")
+
+        assert request == {"contents": ["test text content"], "detector_params": {}}
+
+    def test_build_request_with_params(self):
+        """Test request includes custom detector_params"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.detector_params = {"sensitivity": "high", "language": "en"}
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        request = client.build_request("test text")
+
+        assert request["detector_params"] == {"sensitivity": "high", "language": "en"}
+
+    def test_build_request_with_special_characters(self):
+        """Test request handles special characters in text"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        text = "Text with 'quotes' and \"double quotes\" and\nnewlines"
+        request = 
+
+        assert request["contents"][0] == text
+
+
+class TestParseResponse:
+    """Tests for DetectionsAPIClient.parse_response()"""
+
+    def test_parse_response_http_200_with_detection_above_threshold(self):
+        """Test successful response with detection that exceeds threshold"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "pii-detector"
+        mock_config.threshold = 0.5
+
+        client = DetectionsAPIClient(mock_config, "pii-detector")
+
+        response = [
+            [
+                {
+                    "start": 10,
+                    "end": 25,
+                    "text": "test@email.com",
+                    "detection_type": "pii",
+                    "detection": "EmailAddress",
+                    "score": 0.95,
+                    "evidence": {},
+                    "metadata": {},
+                }
+            ]
+        ]
+
+        result = client.parse_response(response, 200)
+
+        assert result.allowed is False
+        assert result.score == 0.95
+        assert result.label == "pii:EmailAddress"
+        assert "Blocked by pii:EmailAddress" in result.reason
+        assert result.metadata["detection_count"] == 1
+
+    def test_parse_response_http_200_below_threshold(self):
+        """Test response with detections all below threshold"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.threshold = 0.8
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [
+            [
+                {
+                    "start": 0,
+                    "end": 10,
+                    "text": "test",
+                    "detection_type": "toxicity",
+                    "detection": "mild",
+                    "score": 0.3,
+                    "evidence": {},
+                    "metadata": {},
+                }
+            ]
+        ]
+
+        result = client.parse_response(response, 200)
+
+        assert result.allowed is True
+        assert result.label == "BELOW_THRESHOLD"
+        assert "below threshold" in result.reason
+        assert result.metadata["detection_count"] == 0
+        assert result.metadata["total_detections"] == 1
+
+    def test_parse_response_http_200_no_detections(self):
+        """Test response with empty detection list"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [[]]  # Empty detections
+
+        result = client.parse_response(response, 200)
+
+        assert result.allowed is True
+        assert result.score == 0.0
+        assert result.label == "NONE"
+        assert result.reason == "No detections found"
+        assert result.metadata["detection_count"] == 0
+
+    def test_parse_response_http_200_multiple_detections(self):
+        """Test response with multiple detections above threshold"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.threshold = 0.5
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [
+            [
+                {
+                    "start": 0,
+                    "end": 10,
+                    "text": "bad word",
+                    "detection_type": "toxicity",
+                    "detection": "profanity",
+                    "score": 0.9,
+                    "evidence": {},
+                    "metadata": {},
+                },
+                {
+                    "start": 20,
+                    "end": 30,
+                    "text": "attack",
+                    "detection_type": "toxicity",
+                    "detection": "violence",
+                    "score": 0.8,
+                    "evidence": {},
+                    "metadata": {},
+                },
+            ]
+        ]
+
+        result = client.parse_response(response, 200)
+
+        assert result.allowed is False
+        assert result.score == 0.9  # Highest score
+        assert "2 detections" in result.reason
+        assert result.metadata["detection_count"] == 2
+
+    def test_parse_response_http_200_mixed_threshold(self):
+        """Test response with detections both above and below threshold"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.threshold = 0.7
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [
+            [
+                {
+                    "detection_type": "pii",
+                    "detection": "email",
+                    "score": 0.9,
+                    "start": 0,
+                    "end": 10,
+                    "text": "a",
+                },  # Above
+                {
+                    "detection_type": "pii",
+                    "detection": "phone",
+                    "score": 0.4,
+                    "start": 20,
+                    "end": 30,
+                    "text": "b",
+                },  # Below
+                {
+                    "detection_type": "pii",
+                    "detection": "ssn",
+                    "score": 0.8,
+                    "start": 40,
+                    "end": 50,
+                    "text": "c",
+                },  # Above
+            ]
+        ]
+
+        result = client.parse_response(response, 200)
+
+        assert result.allowed is False
+        assert result.score == 0.9  # Highest of filtered
+        assert result.metadata["detection_count"] == 2  # Only above threshold
+        assert result.metadata["total_detections"] == 3  # All detections
+
+        # Check passed flag
+        detections = result.metadata["detections"]
+        assert detections[0]["passed"] is False  # score 0.9 >= 0.7
+        assert detections[1]["passed"] is True  # score 0.4 < 0.7
+        assert detections[2]["passed"] is False  # score 0.8 >= 0.7
+
+    def test_parse_response_http_404(self):
+        """Test HTTP 404 returns NOT_FOUND error"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        result = client.parse_response({}, 404)
+
+        assert result.allowed is False
+        assert result.score == 0.0
+        assert result.label == "NOT_FOUND"
+        assert "not found" in result.reason.lower()
+        assert result.metadata["http_status"] == 404
+
+    def test_parse_response_http_422(self):
+        """Test HTTP 422 returns VALIDATION_ERROR"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        result = client.parse_response({}, 422)
+
+        assert result.allowed is False
+        assert result.score == 0.0
+        assert result.label == "VALIDATION_ERROR"
+        assert "Invalid request" in result.reason
+        assert result.metadata["http_status"] == 422
+
+    def test_parse_response_http_500(self):
+        """Test HTTP 500 returns ERROR"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        result = client.parse_response({}, 500)
+
+        assert result.allowed is False
+        assert result.label == "ERROR"
+        assert "HTTP 500" in result.reason
+        assert result.metadata["http_status"] == 500
+
+    def test_parse_response_invalid_format_not_list(self):
+        """Test invalid response format (not a list) - returns INVALID_RESPONSE label"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        # Response is dict instead of list
+        result = client.parse_response({"error": "bad format"}, 200)
+
+        assert result.allowed is False
+        # Non-list payloads are surfaced as INVALID_RESPONSE, not a generic ERROR
+        assert result.label == "INVALID_RESPONSE"
+        assert "Invalid response format" in result.reason
+        assert result.metadata["response_type"] == "dict"
+
+    def test_parse_response_empty_list(self):
+        """Test empty response list"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        result = client.parse_response([], 200)
+
+        assert result.allowed is True
+        assert result.label == "NONE"
+        assert result.reason == "No detections found"
+
+
+class TestExtractDetectionsFromResponse:
+    """Tests for _extract_detections_from_response() helper"""
+
+    def test_extract_from_valid_nested_array(self):
+        """Test extraction from valid nested array structure"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [[{"detection": "test1", "score": 0.9}, {"detection": "test2", "score": 0.8}]]
+
+        detections = client._extract_detections_from_response(response)
+
+        assert len(detections) == 2
+        assert detections[0]["detection"] == "test1"
+        assert detections[1]["detection"] == "test2"
+
+    def test_extract_from_empty_response(self):
+        """Test extraction from empty response"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        detections = client._extract_detections_from_response([])
+
+        assert detections == []
+
+    def test_extract_from_empty_inner_array(self):
+        """Test extraction when inner array is empty"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [[]]  # Empty inner array
+
+        detections = client._extract_detections_from_response(response)
+
+        assert detections == []
+
+
+class TestGetHighestScoreDetection:
+    """Tests for _get_highest_score_detection() helper"""
+
+    def test_get_highest_from_multiple(self):
+        """Test finding highest score from multiple detections"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        detections = [
+            {"detection": "low", "score": 0.3},
+            {"detection": "high", "score": 0.9},
+            {"detection": "medium", "score": 0.6},
+        ]
+
+        highest = client._get_highest_score_detection(detections)
+
+        assert highest["detection"] == "high"
+        assert highest["score"] == 0.9
+
+    def test_get_highest_from_single(self):
+        """Test with single detection"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        detections = [{"detection": "only", "score": 0.7}]
+
+        highest = client._get_highest_score_detection(detections)
+
+        assert highest["detection"] == "only"
+
+    def test_get_highest_from_empty_list(self):
+        """Test with empty detection list"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        highest = client._get_highest_score_detection([])
+
+        assert highest == {}
+
+    def test_get_highest_missing_score_field(self):
+        """Test handling detections missing score field"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        detections = [
+            {"detection": "no-score"},  # Missing score, defaults to 0.0
+            {"detection": "has-score", "score": 0.5},
+        ]
+
+        highest = client._get_highest_score_detection(detections)
+
+        assert highest["detection"] == "has-score"
+
+
+class TestBuildReasonMessage:
+    """Tests for _build_reason_message() helper"""
+
+    def test_build_reason_no_detections(self):
+        """Test reason message with no detections"""
detections""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + + client = DetectionsAPIClient(mock_config, "test-detector") + + reason = client._build_reason_message([]) + + assert reason == "No detections found" + + def test_build_reason_single_detection(self): + """Test reason message with single detection""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + + client = DetectionsAPIClient(mock_config, "test-detector") + + detections = [{"detection_type": "pii", "detection": "EmailAddress", "score": 0.95}] + + reason = client._build_reason_message(detections) + + assert "Blocked by pii:EmailAddress" in reason + assert "score=0.95" in reason + + def test_build_reason_multiple_detections_same_type(self): + """Test reason message with multiple detections of same type""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + + client = DetectionsAPIClient(mock_config, "test-detector") + + detections = [ + {"detection_type": "pii", "detection": "email", "score": 0.9}, + {"detection_type": "pii", "detection": "phone", "score": 0.8}, + ] + + reason = client._build_reason_message(detections) + + assert "2 detections" in reason + assert "1 type(s)" in reason # Same type + assert "0.90" in reason # Highest score + + def test_build_reason_multiple_detections_different_types(self): + """Test reason message with multiple detection types""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + + client = DetectionsAPIClient(mock_config, "test-detector") + + detections = [ + {"detection_type": "pii", "detection": "email", "score": 0.9}, + {"detection_type": "toxicity", "detection": "hate", "score": 0.85}, + ] + + reason = client._build_reason_message(detections) + + assert "2 detections" in reason + assert "2 type(s)" in reason # Different types + + def test_build_reason_missing_fields_uses_unknown(self): + """Test handling detections with missing fields""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + + client = DetectionsAPIClient(mock_config, "test-detector") + + detections = [ + { + # Missing detection_type and detection fields + "score": 0.9 + } + ] + + reason = client._build_reason_message(detections) + + assert "unknown:unknown" in reason + + +class TestDetectIntegration: + """Integration tests for detect() method""" + + @pytest.mark.asyncio + async def test_detect_successful_flow(self): + """Test complete detect flow with successful detection""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com/api" + mock_config.detector_id = "test-detector-id" + mock_config.threshold = 0.5 + mock_config.timeout = 30 + mock_config.detector_params = {} + + client = DetectionsAPIClient(mock_config, "test-detector") + + # Mock _call_endpoint to return detection response + mock_response = [ + [ + { + "start": 0, + "end": 10, + "text": "test", + "detection_type": "toxicity", + "detection": "profanity", + "score": 0.95, + "evidence": {}, + "metadata": {}, + } + ] + ] + + with patch.object(client, "_call_endpoint", return_value=(mock_response, 200)): + result = await client.detect("test message") + + assert result.allowed is False + assert result.score == 0.95 + assert result.label == "toxicity:profanity" + + @pytest.mark.asyncio + async def 
+        """Test detect() handles exceptions via _handle_error()"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com/api"
+        mock_config.detector_id = "test-id"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        # Mock _call_endpoint to raise exception
+        with patch.object(client, "_call_endpoint", side_effect=Exception("Network error")):
+            result = await client.detect("test message")
+
+        assert result.allowed is False
+        assert result.label == "ERROR"
+        assert "Network error" in result.reason
+
+    @pytest.mark.asyncio
+    async def test_detect_with_custom_headers(self):
+        """Test detect() sends detector-id header"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com/api"
+        mock_config.detector_id = "custom-detector-123"
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        mock_response = [[]]
+
+        with patch.object(client, "_call_endpoint", return_value=(mock_response, 200)) as mock_call:
+            await client.detect("test")
+
+        # Verify detector-id header was passed
+        call_kwargs = mock_call.call_args[1]
+        assert call_kwargs["headers"]["detector-id"] == "custom-detector-123"
+
+
+class TestEdgeCases:
+    """Edge case tests for DetectionsAPIClient"""
+
+    def test_parse_response_detection_missing_score(self):
+        """Test handling detection without score field"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.threshold = 0.5
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [
+            [
+                {
+                    "start": 0,
+                    "end": 10,
+                    "text": "test",
+                    "detection_type": "pii",
+                    "detection": "email",
+                    # Missing score - should default to 0.0
+                }
+            ]
+        ]
+
+        result = client.parse_response(response, 200)
+
+        # Score defaults to 0.0, which is below threshold 0.5
+        assert result.allowed is True
+        assert result.label == "BELOW_THRESHOLD"
+
+    def test_parse_response_zero_threshold(self):
+        """Test with threshold set to 0.0 (everything blocks)"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.threshold = 0.0
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [[{"detection_type": "test", "detection": "low", "score": 0.01, "start": 0, "end": 1, "text": "a"}]]
+
+        result = client.parse_response(response, 200)
+
+        # Even a tiny score meets the 0.0 threshold
+        assert result.allowed is False
+
+    def test_parse_response_threshold_one(self):
+        """Test with threshold set to 1.0 (only a perfect-score detection blocks)"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.threshold = 1.0
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [[{"detection_type": "test", "detection": "high", "score": 0.99, "start": 0, "end": 1, "text": "a"}]]
+
+        result = client.parse_response(response, 200)
+
+        # 0.99 < 1.0, so below threshold
+        assert result.allowed is True
+        assert result.label == "BELOW_THRESHOLD"
+
+    def test_parse_response_exact_threshold_match(self):
+        """Test detection score exactly equals threshold (should block)"""
+        mock_config = Mock()
+        mock_config.inference_endpoint = "http://test.com"
+        mock_config.detector_id = "test-id"
+        mock_config.threshold = 0.7
+
+        client = DetectionsAPIClient(mock_config, "test-detector")
+
+        response = [[{"detection_type": "test", "detection": "exact", "score": 0.7, "start": 0, "end": 1, "text": "a"}]]
"a"}]] + + result = client.parse_response(response, 200) + + # score >= threshold, so blocks + assert result.allowed is False + assert result.score == 0.7 + + +class TestMetadataConsistency: + """Tests for metadata structure consistency""" + + def test_metadata_structure_below_threshold(self): + """Test metadata has consistent structure for BELOW_THRESHOLD case""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + mock_config.threshold = 0.9 + + client = DetectionsAPIClient(mock_config, "test-detector") + + response = [[{"detection_type": "test", "detection": "low", "score": 0.3, "start": 0, "end": 1, "text": "a"}]] + + result = client.parse_response(response, 200) + + # Verify consistent metadata structure + assert "detection_count" in result.metadata + assert "total_detections" in result.metadata + assert "individual_scores" in result.metadata + assert "highest_detection" in result.metadata + assert "detections" in result.metadata + + # Verify passed flag exists + assert "passed" in result.metadata["detections"][0] + assert result.metadata["detections"][0]["passed"] is True + + def test_metadata_structure_blocked(self): + """Test metadata has consistent structure for BLOCKED case""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + mock_config.threshold = 0.5 + + client = DetectionsAPIClient(mock_config, "test-detector") + + response = [[{"detection_type": "pii", "detection": "email", "score": 0.9, "start": 0, "end": 10, "text": "a"}]] + + result = client.parse_response(response, 200) + + # Same fields as BELOW_THRESHOLD case + assert "detection_count" in result.metadata + assert "total_detections" in result.metadata + assert "individual_scores" in result.metadata + assert "highest_detection" in result.metadata + assert "detections" in result.metadata + assert "passed" in result.metadata["detections"][0] + assert result.metadata["detections"][0]["passed"] is False + + def test_metadata_individual_scores_includes_all(self): + """Test individual_scores includes ALL detections, not just filtered""" + mock_config = Mock() + mock_config.inference_endpoint = "http://test.com" + mock_config.detector_id = "test-id" + mock_config.threshold = 0.7 + + client = DetectionsAPIClient(mock_config, "test-detector") + + response = [ + [ + {"detection_type": "a", "detection": "1", "score": 0.9, "start": 0, "end": 1, "text": "a"}, # Above + {"detection_type": "b", "detection": "2", "score": 0.5, "start": 0, "end": 1, "text": "b"}, # Below + {"detection_type": "c", "detection": "3", "score": 0.8, "start": 0, "end": 1, "text": "c"}, # Above + ] + ] + + result = client.parse_response(response, 200) + + # Should include ALL 3 scores, not just the 2 above threshold + assert len(result.metadata["individual_scores"]) == 3 + assert result.metadata["individual_scores"] == [0.9, 0.5, 0.8] + + # But detection_count should be 2 (filtered) + assert result.metadata["detection_count"] == 2