diff --git a/docs/user-guides/community/openai.md b/docs/user-guides/community/openai.md
new file mode 100644
index 000000000..6e6a43dff
--- /dev/null
+++ b/docs/user-guides/community/openai.md
@@ -0,0 +1,126 @@
+## OpenAI API Compatibility for NeMo Guardrails
+
+NeMo Guardrails provides server-side compatibility with the OpenAI API, so applications that use OpenAI clients can integrate with NeMo Guardrails and add guardrails to their LLM interactions. Point your OpenAI client at `http://localhost:8000` (or your server URL) and use the standard `/v1/chat/completions` and `/v1/models` endpoints.
+
+## Feature Support Matrix
+
+The following table outlines which OpenAI API features are currently supported when using NeMo Guardrails:
+
+| Feature | Status | Notes |
+| :------ | :----: | :---- |
+| **Basic Chat Completion** | ✔ Supported | Full support for standard chat completions with guardrails applied |
+| **Streaming Responses** | ✔ Supported | Server-Sent Events (SSE) streaming with `stream=true` |
+| **List Models** | ✔ Supported | Returns available guardrails configurations as models |
+| **Multimodal Input** | ✖ Unsupported | Text and image inputs (vision models) are supported with guardrails, but not yet through the OpenAI-compatible endpoints |
+| **Function Calling** | ✖ Unsupported | Not yet implemented; guardrails need structured output support |
+| **Tools** | ✖ Unsupported | Related to function calling; requires action flow integration |
+| **Response Format (JSON Mode)** | ✖ Unsupported | Structured output with guardrails requires additional validation logic |
+
+## Example Usage
+Export the main model's base URL, engine, and API key as environment variables:
+
+```
+export MAIN_MODEL_BASE_URL="http://model-server/v1"
+export MAIN_MODEL_ENGINE="openai"  # or "nim", "vllm", etc.
+export MAIN_MODEL_API_KEY="model-server-api-key"  # or leave empty if not needed
+```
+
+**NOTE**: If these variables are not set, the defaults are:
+* `MAIN_MODEL_BASE_URL`: `https://localhost:8000/v1`
+* `MAIN_MODEL_ENGINE`: `nim`
+* `MAIN_MODEL_API_KEY`: `None`
+
+## Basic Chat Completion
+
+```
+$ curl -X POST http://0.0.0.0:8000/v1/chat/completions \
+  -H 'Accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "model": "nemoguards",
+    "messages": [
+      {
+        "role": "user",
+        "content": "What can you do for me?"
+      }
+    ],
+    "max_tokens": 256,
+    "temperature": 1,
+    "top_p": 1
+  }'
+```
+
+**NOTE**: The `model` value maps to the guardrails configuration ID (`config_id="nemoguards"` in this example). You can also specify `config_id` explicitly:
+
+```
+{
+  "config_id": "my-config",
+  "messages": [...]
+}
+```
+
+## Streaming Chat Completion
+
+```
+$ curl -X POST http://0.0.0.0:8000/v1/chat/completions \
+  -H 'Accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "model": "nemoguards",
+    "messages": [
+      {
+        "role": "user",
+        "content": "What can you do for me?"
+ } + ], + "max_tokens": 256, + "stream": true, + "temperature": 1, + "top_p": 1, + }' +``` + +## List Available Models + +``` +$ curl -X GET http://0.0.0.0:8000/v1/models \ + -H 'Accept: application/json' +``` + +*Example output*: +``` +{ + "object": "list", + "data": [ + { + "id": "gpt-3.5-turbo-instruct", + "object": "model", + "created": 1234567890, + "owned_by": "nemo-guardrails", + "config_id": "abc" + } + ] +} +``` + +## Using with the OpenAI Python Client + +``` +from openai import OpenAI + +# Point to your NeMo Guardrails server +client = OpenAI( + api_key=None, + base_url="http://localhost:8000/v1" +) + +# Use the model field to specify your guardrails config +response = client.chat.completions.create( + model="nemoguards", # Your config ID + messages=[ + {"role": "user", "content": "Hello!"} + ] +) + +print(response.choices[0].message.content) +``` diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index 6980714bc..9cbbcb776 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -31,6 +31,7 @@ ColangSyntaxError, ) from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus +from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state from nemoguardrails.colang.v2_x.runtime.statemachine import ( FlowConfig, InternalEvent, @@ -394,10 +395,13 @@ async def process_events( state = State(flow_states={}, flow_configs=self.flow_configs, rails_config=self.config) initialize_state(state) elif isinstance(state, dict): - # TODO: Implement dict to State conversion - raise NotImplementedError() - # if isinstance(state, dict): - # state = State.from_dict(state) + # Convert dict to State object + if state.get("version") == "2.x" and "state" in state: + # Handle the serialized state format from API calls + state = json_to_state(state["state"]) + else: + # TODO: Implement other dict to State conversion formats if needed + raise NotImplementedError("Unsupported state dict format") assert isinstance(state, State) assert state.main_flow_state is not None diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 923d019f1..81c93ad9a 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -95,6 +95,10 @@ class ModelCacheConfig(BaseModel): ) +class ModelRegistry(BaseModel): + pass + + class Model(BaseModel): """Configuration of a model used by the rails engine. diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 0833710fa..f8f667a62 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -500,6 +500,11 @@ def _init_llms(self): if not self.llm: self.llm = llm_model self.runtime.register_action_param("llm", self.llm) + self._configure_main_llm_streaming( + self.llm, + model_name=llm_config.model, + provider_name=llm_config.engine, + ) else: model_name = f"{llm_config.type}_llm" if not hasattr(self, model_name): diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 658cffd01..99c4defa3 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import asyncio import contextvars import importlib.util @@ -20,24 +21,32 @@ import os.path import re import time -import warnings +import uuid from contextlib import asynccontextmanager -from typing import Any, Callable, List, Optional +from typing import Any, AsyncIterator, Callable, List, Optional, Union from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, Field, root_validator, validator +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from pydantic import BaseModel, Field, ValidationError, root_validator, validator from starlette.responses import StreamingResponse from starlette.staticfiles import StaticFiles from nemoguardrails import LLMRails, RailsConfig, utils -from nemoguardrails.rails.llm.options import ( - GenerationLog, - GenerationOptions, - GenerationResponse, -) +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse from nemoguardrails.server.datastore.datastore import DataStore -from nemoguardrails.streaming import StreamingHandler +from nemoguardrails.server.schemas.openai import ( + GuardrailsChatCompletion, + GuardrailsModel, + GuardrailsModelsResponse, +) +from nemoguardrails.server.schemas.utils import ( + create_error_chat_completion, + extract_bot_message_from_response, + format_streaming_chunk_as_sse, + generation_response_to_chat_completion, +) logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) @@ -228,16 +237,60 @@ class RequestBody(BaseModel): default=None, description="A state object that should be used to continue the interaction.", ) + # Standard OpenAI completion parameters + model: str = Field( + default="main", + description="The model to use for chat completion. 
Maps to the main model in the config.", + ) + max_tokens: Optional[int] = Field( + default=None, + description="The maximum number of tokens to generate.", + ) + temperature: Optional[float] = Field( + default=None, + description="Sampling temperature to use.", + ) + top_p: Optional[float] = Field( + default=None, + description="Top-p sampling parameter.", + ) + stop: Optional[str] = Field( + default=None, + description="Stop sequences.", + ) + presence_penalty: Optional[float] = Field( + default=None, + description="Presence penalty parameter.", + ) + frequency_penalty: Optional[float] = Field( + default=None, + description="Frequency penalty parameter.", + ) + function_call: Optional[dict] = Field( + default=None, + description="Function call parameter.", + ) + logit_bias: Optional[dict] = Field( + default=None, + description="Logit bias parameter.", + ) + log_probs: Optional[bool] = Field( + default=None, + description="Log probabilities parameter.", + ) @root_validator(pre=True) def ensure_config_id(cls, data: Any) -> Any: if isinstance(data, dict): if data.get("config_id") is not None and data.get("config_ids") is not None: raise ValueError("Only one of config_id or config_ids should be specified") - if data.get("config_id") is None and data.get("config_ids") is not None: - data["config_id"] = None + + # Map OpenAI 'model' field to 'config_id' if config_id is not provided if data.get("config_id") is None and data.get("config_ids") is None: - warnings.warn("No config_id or config_ids provided, using default config_id") + model = data.get("model") + if model and model != "main": + # Use model as config_id for OpenAI compatibility + data["config_id"] = model return data @validator("config_ids", pre=True, always=True) @@ -248,21 +301,72 @@ def ensure_config_ids(cls, v, values): return v -class ResponseBody(BaseModel): - messages: Optional[List[dict]] = Field(default=None, description="The new messages in the conversation") - llm_output: Optional[dict] = Field( - default=None, - description="Contains any additional output coming from the LLM.", - ) - output_data: Optional[dict] = Field( - default=None, - description="The output data, i.e. a dict with the values corresponding to the `output_vars`.", - ) - log: Optional[GenerationLog] = Field(default=None, description="Additional logging information.") - state: Optional[dict] = Field( - default=None, - description="A state object that should be used to continue the interaction in the future.", - ) +@app.get( + "/v1/models", + response_model=GuardrailsModelsResponse, + summary="List available models", + description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.", +) +async def get_models(): + """Returns the list of available models (guardrails configurations) in OpenAI-compatible format.""" + + # Use the same logic as get_rails_configs to find available configurations + if app.single_config_mode: + config_ids = [app.single_config_id] if app.single_config_id else [] + + else: + config_ids = [ + f + for f in os.listdir(app.rails_config_path) + if os.path.isdir(os.path.join(app.rails_config_path, f)) + and f[0] != "." + and f[0] != "_" + # Filter out all the configs for which there is no `config.yml` file. 
+ and ( + os.path.exists(os.path.join(app.rails_config_path, f, "config.yml")) + or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml")) + ) + ] + + models = [] + for config_id in config_ids: + try: + # Load the RailsConfig to extract model information + if app.single_config_mode: + config_path = app.rails_config_path + else: + config_path = os.path.join(app.rails_config_path, config_id) + + rails_config = RailsConfig.from_path(config_path) + # Extract all models from this config + config_models = rails_config.models + + if len(config_models) == 0: + guardrails_model = GuardrailsModel( + id=config_id, + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + config_id=config_id, + ) + models.append(guardrails_model) + else: + for model in config_models: + # Only include models with a model name + if model.model: + guardrails_model = GuardrailsModel( + id=model.model, + object="model", + created=int(time.time()), + owned_by="nemo-guardrails", + config_id=config_id, + ) + models.append(guardrails_model) + except Exception as ex: + log.warning(f"Could not load model info for config {config_id}: {ex}") + continue + + return GuardrailsModelsResponse(data=models) @app.get( @@ -305,6 +409,14 @@ def _generate_cache_key(config_ids: List[str]) -> str: return "-".join((config_ids)) # remove sorted +def _get_main_model_name(rails_config: RailsConfig) -> Optional[str]: + """Extracts the main model name from a RailsConfig.""" + main_models = [m for m in rails_config.models if m.type == "main"] + if main_models and main_models[0].model: + return main_models[0].model + return None + + def _get_rails(config_ids: List[str]) -> LLMRails: """Returns the rails instance for the given config id.""" @@ -355,9 +467,85 @@ def _get_rails(config_ids: List[str]) -> LLMRails: return llm_rails +class ChunkErrorMetadata(BaseModel): + message: str + type: str + param: str + code: str + + +class ChunkError(BaseModel): + error: ChunkErrorMetadata + + +async def _format_streaming_response( + stream_iterator: AsyncIterator[Union[str, dict]], model_name: str +) -> AsyncIterator[str]: + """ + Format streaming chunks from LLMRails.stream_async() as SSE events. + + Args: + stream_iterator: AsyncIterator from stream_async() that yields str or dict chunks + model_name: The model name to include in the chunks + + Yields: + SSE-formatted strings (data: {...}\n\n) + """ + # Use "unknown" as default if model_name is None + model = model_name or "unknown" + + try: + async for chunk in stream_iterator: + # Format the chunk as SSE using the utility function + processed_chunk = process_chunk(chunk) + if isinstance(processed_chunk, ChunkError): + # Yield the error and stop streaming + yield f"data: {json.dumps(processed_chunk.model_dump())}\n\n" + return + else: + yield format_streaming_chunk_as_sse(processed_chunk, model) + + finally: + # Always send [DONE] event when stream ends + yield "data: [DONE]\n\n" + + +def process_chunk(chunk: Any) -> Union[Any, ChunkError]: + """ + Processes a single chunk from the stream. + + Args: + chunk: A single chunk from the stream (can be str, dict, or other type). + model: The model name (not used in processing but kept for signature consistency). + + Returns: + Union[Any, StreamingError]: StreamingError instance for errors or the original chunk. 
+ """ + # Convert chunk to string for JSON parsing if needed + chunk_str = chunk if isinstance(chunk, str) else json.dumps(chunk) if isinstance(chunk, dict) else str(chunk) + + try: + validated_data = ChunkError.model_validate_json(chunk_str) + return validated_data # Return the StreamingError instance directly + except ValidationError: + # Not an error, just a normal token + pass + except json.JSONDecodeError: + # Invalid JSON format, treat as normal token + pass + except Exception as e: + log.warning( + f"Unexpected error processing stream chunk: {type(e).__name__}: {str(e)}", + extra={"chunk": chunk_str}, + ) + + # Return the original chunk + return chunk + + @app.post( "/v1/chat/completions", - response_model=ResponseBody, + response_model=GuardrailsChatCompletion, response_model_exclude_none=True, ) async def chat_completion(body: RequestBody, request: Request): @@ -375,6 +563,7 @@ async def chat_completion(body: RequestBody, request: Request): # Use Request config_ids if set, otherwise use the FastAPI default config. # If neither is available we can't generate any completions as we have no config_id config_ids = body.config_ids + if not config_ids: if app.default_config_id: config_ids = [app.default_config_id] @@ -383,19 +572,20 @@ async def chat_completion(body: RequestBody, request: Request): try: llm_rails = _get_rails(config_ids) + except ValueError as ex: log.exception(ex) - return ResponseBody( - messages=[ - { - "role": "assistant", - "content": f"Could not load the {config_ids} guardrails configuration. " - f"An internal error has occurred.", - } - ] + return create_error_chat_completion( + model=config_ids[0] if config_ids else "unknown", + error_message=f"Could not load the {config_ids} guardrails configuration. An internal error has occurred.", + config_id=config_ids[0] if config_ids else None, ) try: + main_model_name = _get_main_model_name(llm_rails.config) + if main_model_name is None: + main_model_name = config_ids[0] if config_ids else "unknown" + messages = body.messages or [] if body.context: messages.insert(0, {"role": "context", "content": body.context}) @@ -406,16 +596,12 @@ async def chat_completion(body: RequestBody, request: Request): if body.thread_id: if datastore is None: raise RuntimeError("No DataStore has been configured.") - # We make sure the `thread_id` meets the minimum complexity requirement. if len(body.thread_id) < 16: - return ResponseBody( - messages=[ - { - "role": "assistant", - "content": "The `thread_id` must have a minimum length of 16 characters.", - } - ] + return create_error_chat_completion( + model=main_model_name, + error_message="The `thread_id` must have a minimum length of 16 characters.", + config_id=config_ids[0] if config_ids else None, ) # Fetch the existing thread messages. For easier management, we prepend @@ -426,56 +612,82 @@ async def chat_completion(body: RequestBody, request: Request): # And prepend them. 
messages = thread_messages + messages + generation_options = body.options + + # Initialize llm_params if not already set + if generation_options.llm_params is None: + generation_options.llm_params = {} + + # Set OpenAI-compatible parameters in llm_params + if body.max_tokens: + generation_options.llm_params["max_tokens"] = body.max_tokens + if body.temperature is not None: + generation_options.llm_params["temperature"] = body.temperature + if body.top_p is not None: + generation_options.llm_params["top_p"] = body.top_p + if body.stop: + generation_options.llm_params["stop"] = body.stop + if body.presence_penalty is not None: + generation_options.llm_params["presence_penalty"] = body.presence_penalty + if body.frequency_penalty is not None: + generation_options.llm_params["frequency_penalty"] = body.frequency_penalty if body.stream and llm_rails.config.streaming_supported and llm_rails.main_llm_supports_streaming: - # Create the streaming handler instance - streaming_handler = StreamingHandler() - - # Start the generation - asyncio.create_task( - llm_rails.generate_async( - messages=messages, - streaming_handler=streaming_handler, - options=body.options, - state=body.state, - ) + # Use stream_async for streaming with output rails support + stream_iterator = llm_rails.stream_async( + messages=messages, + options=generation_options, + state=body.state, ) - # TODO: Add support for thread_ids in streaming mode - - return StreamingResponse(streaming_handler) + return StreamingResponse( + _format_streaming_response(stream_iterator, model_name=main_model_name), + media_type="text/event-stream", + ) else: - res = await llm_rails.generate_async(messages=messages, options=body.options, state=body.state) + res = await llm_rails.generate_async(messages=messages, options=generation_options, state=body.state) - if isinstance(res, GenerationResponse): - bot_message_content = res.response[0] - # Ensure bot_message is always a dict - if isinstance(bot_message_content, str): - bot_message = {"role": "assistant", "content": bot_message_content} - else: - bot_message = bot_message_content - else: - assert isinstance(res, dict) - bot_message = res + # Extract bot message for thread storage if needed + bot_message = extract_bot_message_from_response(res) # If we're using threads, we also need to update the data before returning # the message. 
if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - result = ResponseBody(messages=[bot_message]) - - # If we have additional GenerationResponse fields, we return as well + # Build the response with OpenAI-compatible format using utility function if isinstance(res, GenerationResponse): - result.llm_output = res.llm_output - result.output_data = res.output_data - result.log = res.log - result.state = res.state - - return result + return generation_response_to_chat_completion( + response=res, + model=main_model_name, + config_id=config_ids[0] if config_ids else None, + ) + else: + # For dict responses, convert to basic chat completion + return GuardrailsChatCompletion( + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=main_model_name, + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content=bot_message.get("content", ""), + ), + finish_reason="stop", + logprobs=None, + ) + ], + ) except Exception as ex: log.exception(ex) - return ResponseBody(messages=[{"role": "assistant", "content": "Internal server error."}]) + return create_error_chat_completion( + model=config_ids[0] if config_ids else "unknown", + error_message="Internal server error", + config_id=config_ids[0] if config_ids else None, + ) # By default, there are no challenges diff --git a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py new file mode 100644 index 000000000..7d40aafa0 --- /dev/null +++ b/nemoguardrails/server/schemas/openai.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""OpenAI API schema definitions for the NeMo Guardrails server.""" + +import os +from typing import List, Optional + +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.model import Model +from pydantic import BaseModel, Field + + +class GuardrailsChatCompletion(ChatCompletion): + """OpenAI API response body with NeMo-Guardrails extensions.""" + + config_id: Optional[str] = Field( + default=None, + description="The guardrails configuration ID associated with this response.", + ) + state: Optional[dict] = Field(default=None, description="State object for continuing the conversation.") + llm_output: Optional[dict] = Field(default=None, description="Additional LLM output data.") + output_data: Optional[dict] = Field(default=None, description="Additional output data.") + log: Optional[dict] = Field(default=None, description="Generation log data.") + + +class GuardrailsModel(Model): + """OpenAI API model with NeMo-Guardrails extensions.""" + + config_id: Optional[str] = Field( + default=None, + description="[NeMo Guardrails extension] The guardrails configuration ID associated with this model.", + ) + engine: Optional[str] = Field( + default_factory=lambda: os.getenv("MAIN_MODEL_ENGINE", "nim"), + description="[NeMo Guardrails extension] The engine associated with this model.", + ) + base_url: Optional[str] = Field( + default_factory=lambda: os.getenv("MAIN_MODEL_BASE_URL", "https://localhost:8000/v1"), + description="[NeMo Guardrails extension] The base URL this model serves on.", + ) + api_key_env_var: Optional[str] = Field( + default_factory=lambda: os.getenv("MAIN_MODEL_API_KEY", None), + description="[NeMo Guardrails extension] This model's API key.", + ) + + +class GuardrailsModelsResponse(BaseModel): + """OpenAI API models list response with NeMo-Guardrails extensions.""" + + object: str = Field(default="list", description="The object type, which is always 'list'.") + data: List[GuardrailsModel] = Field(description="The list of models.") diff --git a/nemoguardrails/server/schemas/utils.py b/nemoguardrails/server/schemas/utils.py new file mode 100644 index 000000000..86f950f62 --- /dev/null +++ b/nemoguardrails/server/schemas/utils.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility functions for converting between Guardrails and OpenAI API formats.""" + +import json +import time +import uuid +from typing import Any, Dict, Optional, Tuple, Union + +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from nemoguardrails.rails.llm.options import GenerationResponse +from nemoguardrails.server.schemas.openai import GuardrailsChatCompletion + + +def extract_bot_message_from_response( + response: Union[str, dict, GenerationResponse, Tuple[dict, dict]], +) -> Dict[str, Any]: + """ + Extract the bot message from generate_async response. + + Args: + response: Response from LLMRails.generate_async() which can be: + - str: Direct text response + - dict: Message dict + - GenerationResponse: Full response object + - Tuple[dict, dict]: (message, state) tuple + + Returns: + A dictionary with at least 'role' and 'content' keys + """ + if isinstance(response, GenerationResponse): + bot_message_content = response.response[0] + # Ensure bot_message is always a dict + if isinstance(bot_message_content, str): + return {"role": "assistant", "content": bot_message_content} + else: + return bot_message_content + elif isinstance(response, str): + # Direct string response + return {"role": "assistant", "content": response} + elif isinstance(response, tuple): + # Tuple of (message, state) + bot_message = response[0] + if isinstance(bot_message, dict): + return bot_message + else: + return {"role": "assistant", "content": str(bot_message)} + else: + # Already a dict + return response + + +def generation_response_to_chat_completion( + response: GenerationResponse, + model: str, + config_id: Optional[str] = None, +) -> GuardrailsChatCompletion: + """ + Convert a GenerationResponse to an OpenAI-compatible GuardrailsChatCompletion. + + Args: + response: The GenerationResponse from LLMRails.generate_async() + model: The model name to include in the response + config_id: Optional guardrails configuration ID + + Returns: + A GuardrailsChatCompletion instance compatible with OpenAI API format + """ + bot_message = extract_bot_message_from_response(response) + + # Convert log to dict if present (for JSON serialization) + log_dict = None + if response.log: + if hasattr(response.log, "model_dump"): + log_dict = response.log.model_dump() + elif hasattr(response.log, "dict"): + log_dict = response.log.dict() + elif isinstance(response.log, dict): + log_dict = response.log + else: + # Fallback: try to convert to dict + try: + log_dict = dict(response.log) + except (TypeError, ValueError): + # If conversion fails, skip the log + log_dict = None + + return GuardrailsChatCompletion( + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=model, + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content=bot_message.get("content", ""), + ), + finish_reason="stop", + logprobs=None, + ) + ], + config_id=config_id, + llm_output=response.llm_output, + output_data=response.output_data, + log=log_dict, + state=response.state, + ) + + +def create_error_chat_completion( + model: str, + error_message: str, + config_id: Optional[str] = None, +) -> GuardrailsChatCompletion: + """ + Create an error response in GuardrailsChatCompletion format. 
+ + Args: + model: The model name to include in the response + error_message: The error message to return + config_id: Optional guardrails configuration ID + + Returns: + A GuardrailsChatCompletion instance with the error message + """ + return GuardrailsChatCompletion( + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=model, + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content=error_message, + ), + finish_reason="stop", + logprobs=None, + ) + ], + config_id=config_id, + ) + + +def format_streaming_chunk( + chunk: Any, + model: str, + chunk_id: Optional[str] = None, +) -> Dict[str, Any]: + """ + Format a streaming chunk into OpenAI chat completion chunk format. + + Args: + chunk: The chunk from LLMRails.stream_async() (can be dict, str, or other type) + model: The model name to include in the chunk + chunk_id: Optional ID for the chunk (generates UUID if not provided) + + Returns: + A dictionary in OpenAI streaming chunk format + """ + if chunk_id is None: + chunk_id = f"chatcmpl-{uuid.uuid4()}" + + # Determine the payload format based on chunk type + if isinstance(chunk, dict): + # If chunk is a dict, wrap it in OpenAI chunk format with delta + return { + "id": chunk_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "delta": chunk, + "index": 0, + "finish_reason": None, + } + ], + } + elif isinstance(chunk, str): + try: + # Try parsing as JSON - if it parses, it might be a pre-formed payload + payload = json.loads(chunk) + # Ensure it has the required fields + if "id" not in payload: + payload["id"] = chunk_id + if "model" not in payload: + payload["model"] = model + return payload + except (json.JSONDecodeError, ValueError): + # treat as plain text content token + return { + "id": chunk_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "delta": {"content": chunk}, + "index": 0, + "finish_reason": None, + } + ], + } + else: + # For any other type, treat as plain content + return { + "id": chunk_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "delta": {"content": str(chunk)}, + "index": 0, + "finish_reason": None, + } + ], + } + + +def format_streaming_chunk_as_sse( + chunk: Any, + model: str, + chunk_id: Optional[str] = None, +) -> str: + """ + Format a streaming chunk as a Server-Sent Event (SSE) data line. + + Args: + chunk: The chunk from StreamingHandler + model: The model name to include in the chunk + chunk_id: Optional ID for the chunk + + Returns: + A formatted SSE string (e.g., "data: {...}\\n\\n") + """ + payload = format_streaming_chunk(chunk, model, chunk_id) + data = json.dumps(payload, ensure_ascii=False) + return f"data: {data}\n\n" diff --git a/nemoguardrails/streaming.py b/nemoguardrails/streaming.py index 7cf8ac7c3..7bca4758c 100644 --- a/nemoguardrails/streaming.py +++ b/nemoguardrails/streaming.py @@ -176,15 +176,34 @@ async def _process( chunk: Union[str, object], generation_info: Optional[Dict[str, Any]] = None, ): - """Process a chunk of text. + """Process a chunk of text or dict. If we're in buffering mode, record the text. Otherwise, update the full completion, check for stop tokens, and enqueue the chunk. + Dict chunks bypass completion tracking and go directly to the queue. 
""" if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass buffering and completion tracking + if isinstance(chunk, dict): + if self.pipe_to: + asyncio.create_task(self.pipe_to.push_chunk(chunk)) + else: + if self.include_generation_metadata: + await self.queue.put( + { + "text": chunk, + "generation_info": ( + self.current_generation_info.copy() if self.current_generation_info else {} + ), + } + ) + else: + await self.queue.put(chunk) + return + if self.enable_buffer: if chunk is not END_OF_STREAM: self.buffer += chunk if chunk is not None else "" @@ -254,10 +273,28 @@ async def _process( async def push_chunk( self, - chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None], + chunk: Union[ + str, + # dict, + GenerationChunk, + AIMessageChunk, + ChatGenerationChunk, + None, + object, + ], generation_info: Optional[Dict[str, Any]] = None, ): - """Push a new chunk to the stream.""" + """Push a new chunk to the stream. + + Args: + chunk: The chunk to push. Can be: + - str: Plain text content + - dict: Dictionary with fields like role, content, etc. + - GenerationChunk/AIMessageChunk/ChatGenerationChunk: LangChain chunk types + - None: Signals end of stream (converted to END_OF_STREAM) + - object: END_OF_STREAM sentinel + generation_info: Optional metadata about the generation + """ # if generation_info is not explicitly passed, # try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk @@ -281,6 +318,9 @@ async def push_chunk( elif isinstance(chunk, str): # empty string is a valid chunk and should be processed normally pass + elif isinstance(chunk, dict): + # plain dict chunks are allowed for OpenAI-compatible streaming + pass else: raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}") @@ -291,6 +331,11 @@ async def push_chunk( if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass prefix/suffix processing and go directly to _process + if isinstance(chunk, dict): + await self._process(chunk, generation_info) + return + # Process prefix: accumulate until the expected prefix is received, then remove it. if self.prefix: if chunk is not None and chunk is not END_OF_STREAM: diff --git a/poetry.lock b/poetry.lock index f9976e9aa..23aecd69e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -22,7 +22,7 @@ tests = ["hypothesis", "pytest"] name = "aiofiles" version = "24.1.0" description = "File support for asyncio." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, @@ -935,7 +935,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1841,7 +1841,7 @@ i18n = ["Babel (>=2.7)"] name = "jiter" version = "0.10.0" description = "Fast iterable JSON parser." 
-optional = true +optional = false python-versions = ">=3.9" files = [ {file = "jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303"}, @@ -3040,7 +3040,7 @@ sympy = "*" name = "openai" version = "1.102.0" description = "The official Python library for the openai API" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "openai-1.102.0-py3-none-any.whl", hash = "sha256:d751a7e95e222b5325306362ad02a7aa96e1fab3ed05b5888ce1c7ca63451345"}, @@ -6596,7 +6596,7 @@ files = [ cffi = ["cffi (>=1.17)"] [extras] -all = ["aiofiles", "fast-langdetect", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] +all = ["fast-langdetect", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] eval = ["numpy", "numpy", "numpy", "numpy", "streamlit", "tornado", "tqdm"] gcp = ["google-cloud-language"] jailbreak = ["yara-python"] @@ -6604,9 +6604,9 @@ multilingual = ["fast-langdetect"] nvidia = ["langchain-nvidia-ai-endpoints"] openai = ["langchain-openai"] sdd = ["presidio-analyzer", "presidio-anonymizer"] -tracing = ["aiofiles", "opentelemetry-api"] +tracing = ["opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "5f621add3bdfe92f78c38e14702f22cef7adb3a98b6ca6e494bb44ed834bdd97" +content-hash = "d98bb73f1ab645044c8e92788a70518ae054ce8d53245beb65e07c532b6acfc0" diff --git a/pyproject.toml b/pyproject.toml index 3e864afef..79f8ad922 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,10 +71,11 @@ starlette = ">=0.49.1" typer = ">=0.8" uvicorn = ">=0.23" watchdog = ">=3.0.0," +aiofiles = ">=24.1.0" +openai = ">=1.0.0, <2.0.0" # tracing opentelemetry-api = { version = ">=1.27.0,<2.0.0", optional = true } -aiofiles = { version = ">=24.1.0", optional = true } # openai langchain-openai = { version = ">=0.1.0", optional = true } @@ -113,7 +114,7 @@ sdd = ["presidio-analyzer", "presidio-anonymizer"] eval = ["tqdm", "numpy", "streamlit", "tornado"] openai = ["langchain-openai"] gcp = ["google-cloud-language"] -tracing = ["opentelemetry-api", "aiofiles"] +tracing = ["opentelemetry-api"] nvidia = ["langchain-nvidia-ai-endpoints"] jailbreak = ["yara-python"] multilingual = ["fast-langdetect"] @@ -129,7 +130,6 @@ all = [ "langchain-openai", "google-cloud-language", "opentelemetry-api", - "aiofiles", "langchain-nvidia-ai-endpoints", "yara-python", "fast-langdetect", diff --git a/tests/test_api.py b/tests/test_api.py index b6619fe7a..821e9c9d2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,13 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os +from typing import AsyncIterator, Union +from unittest.mock import patch import pytest from fastapi.testclient import TestClient from nemoguardrails.server import api -from nemoguardrails.server.api import RequestBody +from nemoguardrails.server.api import RequestBody, _format_streaming_response + +LIVE_TEST_MODE = os.environ.get("LIVE_TEST_MODE") or os.environ.get("TEST_LIVE_MODE") client = TestClient(api.app) @@ -41,7 +46,52 @@ def test_get(): assert len(result) > 0 -@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_get_models_default_env_vars(): + """Test the OpenAI-compatible /v1/models endpoint.""" + response = client.get("/v1/models") + assert response.status_code == 200 + + result = response.json() + + # Check OpenAI models list format + assert result["object"] == "list" + assert "data" in result + assert len(result["data"]) > 0 + + # Check each model has the required OpenAI format + for model in result["data"]: + assert "id" in model + assert "config_id" in model + assert model["object"] == "model" + assert "created" in model + assert model["owned_by"] == "nemo-guardrails" + assert model["engine"] == "nim" + assert model["base_url"] == "https://localhost:8000/v1" + assert model["api_key_env_var"] is None + + +def test_get_models_with_custom_env_vars(): + with patch.dict( + os.environ, + { + "MAIN_MODEL_ENGINE": "custom-engine", + "MAIN_MODEL_BASE_URL": "https://custom-api.example.com/v1", + "MAIN_MODEL_API_KEY": "custom-api-key", + }, + ): + response = client.get("/v1/models") + assert response.status_code == 200 + result = response.json() + for model in result["data"]: + assert model["engine"] == "custom-engine" + assert model["base_url"] == "https://custom-api.example.com/v1" + assert model["api_key_env_var"] == "custom-api-key" + + +@pytest.mark.skipif( + not LIVE_TEST_MODE, + reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing", +) def test_chat_completion(): response = client.post( "/v1/chat/completions", @@ -57,11 +107,20 @@ def test_chat_completion(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" -@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +@pytest.mark.skipif( + not LIVE_TEST_MODE, + reason="This test requires LIVE_TEST_MODE or TEST_LIVE_MODE environment variable to be set for live testing", +) def test_chat_completion_with_default_configs(): api.set_default_config_id("general") @@ -78,8 +137,14 @@ def test_chat_completion_with_default_configs(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" def test_request_body_validation(): @@ -113,6 +178,29 @@ def test_request_body_validation(): assert request_body.config_ids is None +def 
test_openai_model_field_mapping(): + """Test OpenAI-compatible model field mapping to config_id.""" + + # Test model field maps to config_id + data = { + "model": "test_model", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = RequestBody.model_validate(data) + assert request_body.model == "test_model" + + # Test model and config_id both provided (config_id takes precedence) + data = { + "model": "test_model", + "config_id": "test_config", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = RequestBody.model_validate(data) + assert request_body.model == "test_model" + assert request_body.config_id == "test_config" + assert request_body.config_ids == ["test_config"] + + def test_request_body_state(): """Test RequestBody state handling.""" data = { @@ -134,6 +222,7 @@ def test_request_body_messages(): ], } request_body = RequestBody.model_validate(data) + assert request_body.messages is not None assert len(request_body.messages) == 2 data = { @@ -141,4 +230,227 @@ def test_request_body_messages(): "messages": [{"content": "Hello"}], } request_body = RequestBody.model_validate(data) + assert request_body.messages is not None assert len(request_body.messages) == 1 + + +async def _create_test_stream(chunks: list) -> AsyncIterator[Union[str, dict]]: + """Helper to create an async iterator for testing.""" + for chunk in chunks: + yield chunk + + +@pytest.mark.asyncio +async def test_openai_sse_format_basic_chunks(): + """Test basic string chunks are properly formatted as SSE events.""" + # Create a test stream with string chunks + stream = _create_test_stream(["Hello ", "world"]) + + # Collect yielded SSE messages + collected = [] + async for b in _format_streaming_response(stream, model_name=None): + collected.append(b) + + # We expect three messages: two data: {json}\n\n events and final data: [DONE]\n\n + assert len(collected) == 3 + # First two are JSON SSE events + evt1 = collected[0] + evt2 = collected[1] + done = collected[2] + + assert evt1.startswith("data: ") + j1 = json.loads(evt1[len("data: ") :].strip()) + assert j1["object"] == "chat.completion.chunk" + assert j1["choices"][0]["delta"]["content"] == "Hello " + + assert evt2.startswith("data: ") + j2 = json.loads(evt2[len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "world" + + assert done == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_model_name(): + """Test that model name is properly included in the response.""" + stream = _create_test_stream(["Test"]) + collected = [] + + async for b in _format_streaming_response(stream, model_name="gpt-4"): + collected.append(b) + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["model"] == "gpt-4" + assert j["choices"][0]["delta"]["content"] == "Test" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_dict_chunk(): + """Test that dict chunks with role and content are properly formatted.""" + stream = _create_test_stream([{"role": "assistant", "content": "Hi!"}]) + collected = [] + + async for b in _format_streaming_response(stream, model_name=None): + collected.append(b) + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["object"] == "chat.completion.chunk" + assert j["choices"][0]["delta"]["role"] == "assistant" + assert j["choices"][0]["delta"]["content"] == "Hi!" 
+ assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_empty_string(): + """Test that empty strings are handled correctly.""" + stream = _create_test_stream([""]) + collected = [] + + async for b in _format_streaming_response(stream, model_name=None): + collected.append(b) + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_none_triggers_done(): + """Test that None values are handled correctly.""" + stream = _create_test_stream(["Content", None]) + collected = [] + + async for b in _format_streaming_response(stream, model_name=None): + collected.append(b) + + assert len(collected) == 3 # Content chunk, None chunk, and [DONE] + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "Content" + assert collected[2] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_multiple_dict_chunks(): + """Test multiple dict chunks with different fields.""" + stream = _create_test_stream([{"role": "assistant"}, {"content": "Hello"}, {"content": " world"}]) + collected = [] + + async for b in _format_streaming_response(stream, model_name="test-model"): + collected.append(b) + + assert len(collected) == 4 + + # Check first chunk (role only) + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["role"] == "assistant" + assert "content" not in j1["choices"][0]["delta"] + + # Check second chunk (content only) + j2 = json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "Hello" + + # Check third chunk (content only) + j3 = json.loads(collected[2][len("data: ") :].strip()) + assert j3["choices"][0]["delta"]["content"] == " world" + + # Check [DONE] message + assert collected[3] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_special_characters(): + """Test that special characters are properly escaped in JSON.""" + stream = _create_test_stream(["Line 1\nLine 2", 'Quote: "test"']) + collected = [] + + async for b in _format_streaming_response(stream, model_name=None): + collected.append(b) + + assert len(collected) == 3 + + # Verify first chunk with newline + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["content"] == "Line 1\nLine 2" + + # Verify second chunk with quotes + j2 = json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == 'Quote: "test"' + + assert collected[2] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_events(): + """Test that all events follow proper SSE format.""" + stream = _create_test_stream(["Test"]) + collected = [] + + async for b in _format_streaming_response(stream, model_name=None): + collected.append(b) + + # All events except [DONE] should be valid JSON with proper SSE format + for event in collected[:-1]: + assert event.startswith("data: ") + assert event.endswith("\n\n") + # Verify it's valid JSON + json_str = event[len("data: ") :].strip() + j = json.loads(json_str) + assert "object" in j + assert "choices" in j + assert isinstance(j["choices"], list) + assert len(j["choices"]) > 0 + + # Last event should be [DONE] + assert collected[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def 
test_openai_sse_format_chunk_metadata(): + """Test that chunk metadata is properly formatted.""" + stream = _create_test_stream(["Test"]) + collected = [] + + async for b in _format_streaming_response(stream, model_name="test-model"): + collected.append(b) + + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + + # Verify all required fields are present + assert "id" in j # id should be present (UUID generated) + assert j["object"] == "chat.completion.chunk" + assert isinstance(j["created"], int) + assert j["model"] == "test-model" + assert isinstance(j["choices"], list) + assert len(j["choices"]) == 1 + + choice = j["choices"][0] + assert "delta" in choice + assert choice["index"] == 0 + assert choice["finish_reason"] is None + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_chat_completion_with_streaming(): + response = client.post( + "/v1/chat/completions", + json={ + "config_id": "general", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/event-stream" + for chunk in response.iter_lines(): + assert chunk.startswith("data: ") + assert chunk.endswith("\n\n") + assert "data: [DONE]\n\n" in response.text diff --git a/tests/test_openai_integration.py b/tests/test_openai_integration.py new file mode 100644 index 000000000..ecfd9b179 --- /dev/null +++ b/tests/test_openai_integration.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os + +import pytest +from fastapi.testclient import TestClient +from openai import OpenAI +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from nemoguardrails.server import api + + +@pytest.fixture(scope="function", autouse=True) +def set_rails_config_path(): + """Set the rails_config_path to test configs.""" + original_path = api.app.rails_config_path + # Use test_configs which have mock LLMs that don't need API keys + test_configs_path = os.path.join(os.path.dirname(__file__), "test_configs") + api.app.rails_config_path = test_configs_path + yield + api.app.rails_config_path = original_path + + +@pytest.fixture(scope="function") +def openai_client(): + """Create an OpenAI client that uses the guardrails FastAPI app via TestClient.""" + # Create a TestClient for the FastAPI app + test_client = TestClient(api.app) + + client = OpenAI( + api_key="dummy-key", + base_url="http://dummy-url/v1", + http_client=test_client, + ) + return client + + +def test_openai_client_list_models(openai_client): + models = openai_client.models.list() + + # Verify the response structure matches the GuardrailsModel schema + assert len(models.data) > 0 + model = models.data[0] + # Verify it's a valid model response with required fields + assert model.object == "model" + assert model.owned_by == "nemo-guardrails" + # Check the extra fields that GuardrailsModel adds + assert hasattr(model, "config_id") + assert hasattr(model, "engine") + assert hasattr(model, "base_url") + assert model.base_url == "https://localhost:8000/v1" + + +def test_openai_client_chat_completion(openai_client): + response = openai_client.chat.completions.create( + model="with_custom_llm", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + assert isinstance(response, ChatCompletion) + assert response.id is not None + + assert response.choices[0] == Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + content="Custom LLM response", + refusal=None, + role="assistant", + annotations=None, + audio=None, + function_call=None, + tool_calls=None, + ), + ) + assert hasattr(response, "created") + + +def test_openai_client_chat_completion_parameterized(openai_client): + response = openai_client.chat.completions.create( + model="with_custom_llm", + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + max_tokens=100, + stream=False, + ) + + # Verify response exists + assert isinstance(response, ChatCompletion) + assert response.id is not None + assert response.choices[0] == Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + content="Custom LLM response", + refusal=None, + role="assistant", + annotations=None, + ), + ) + assert hasattr(response, "created") + + +def test_openai_client_chat_completion_input_rails(openai_client): + response = openai_client.chat.completions.create( + model="with_input_rails", + messages=[{"role": "user", "content": "Hello, how are you?"}], + stream=False, + ) + + # Verify response exists + assert isinstance(response, ChatCompletion) + assert response.id is not None + assert isinstance(response.choices[0], Choice) + assert hasattr(response, "created") + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_openai_client_chat_completion_streaming(openai_client): + stream = openai_client.chat.completions.create( + model="input_rails", + messages=[{"role": "user", 
"content": "Tell me a short joke."}], + stream=True, + ) + + chunks = list(stream) + assert len(chunks) > 0 + + # Verify at least one chunk has content + has_content = any(hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content for chunk in chunks) + assert has_content, "At least one chunk should contain content" + + +def test_openai_client_error_handling_invalid_model(openai_client): + response = openai_client.chat.completions.create( + model="nonexistent_config", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + # The error should be in the content + assert ( + "Could not load" in response.choices[0].message.content + or "error" in response.choices[0].message.content.lower() + ) diff --git a/tests/test_schema_utils.py b/tests/test_schema_utils.py new file mode 100644 index 000000000..58ce576ba --- /dev/null +++ b/tests/test_schema_utils.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from nemoguardrails.rails.llm.options import GenerationLog, GenerationResponse +from nemoguardrails.server.schemas.openai import GuardrailsChatCompletion +from nemoguardrails.server.schemas.utils import ( + create_error_chat_completion, + extract_bot_message_from_response, + format_streaming_chunk, + format_streaming_chunk_as_sse, + generation_response_to_chat_completion, +) + +# ===== Tests for extract_bot_message_from_response ===== + + +def test_extract_bot_message_from_string_response(): + """Test extracting bot message from a plain string response.""" + response = "Hello, how can I help you?" 
+    result = extract_bot_message_from_response(response)
+    assert result == {"role": "assistant", "content": "Hello, how can I help you?"}
+
+
+def test_extract_bot_message_from_dict_response():
+    """Test extracting bot message from a dict response."""
+    response = {"role": "assistant", "content": "Test response"}
+    result = extract_bot_message_from_response(response)
+    assert result == {"role": "assistant", "content": "Test response"}
+
+
+def test_extract_bot_message_from_generation_response_with_string_content():
+    """Test extracting bot message from GenerationResponse with string in list."""
+    response = GenerationResponse(response=[{"role": "assistant", "content": "Hello from bot"}])
+    result = extract_bot_message_from_response(response)
+    assert result == {"role": "assistant", "content": "Hello from bot"}
+
+
+def test_extract_bot_message_from_generation_response_with_dict():
+    """Test extracting bot message from GenerationResponse containing a dict."""
+    bot_msg = {"role": "assistant", "content": "Response from dict"}
+    response = GenerationResponse(response=[bot_msg])
+    result = extract_bot_message_from_response(response)
+    assert result == {"role": "assistant", "content": "Response from dict"}
+
+
+def test_extract_bot_message_from_tuple_with_dict():
+    """Test extracting bot message from a tuple (message, state) with dict message."""
+    response = ({"role": "assistant", "content": "Tuple response"}, {"state": "data"})
+    result = extract_bot_message_from_response(response)
+    assert result == {"role": "assistant", "content": "Tuple response"}
+
+
+# ===== Tests for generation_response_to_chat_completion =====
+
+
+def test_generation_response_to_chat_completion():
+    """Test converting a full GenerationResponse to chat completion."""
+    response = GenerationResponse(
+        response=[{"role": "assistant", "content": "This is a response"}],
+        llm_output={"llm_output": "This is an LLM output"},
+        output_data={"output_data": "This is output data"},
+        log=GenerationLog(),
+        state={"state": "This is a state"},
+    )
+    result = generation_response_to_chat_completion(response=response, model="test_model", config_id="test_config_id")
+    assert isinstance(result, GuardrailsChatCompletion)
+    assert result.id.startswith("chatcmpl-")
+    assert isinstance(result.created, int)
+
+    assert result.object == "chat.completion"
+    assert result.model == "test_model"
+    assert result.config_id == "test_config_id"
+    assert result.choices[0] == Choice(
+        index=0,
+        message=ChatCompletionMessage(role="assistant", content="This is a response"),
+        finish_reason="stop",
+        logprobs=None,
+    )
+    assert result.llm_output == {"llm_output": "This is an LLM output"}
+    assert result.output_data == {"output_data": "This is output data"}
+    assert result.log is not None
+    assert result.state == {"state": "This is a state"}
+
+
+def test_generation_response_to_chat_completion_with_empty_content():
+    """Test converting GenerationResponse with missing content."""
+    response = GenerationResponse(response=[{"role": "assistant", "content": ""}])
+    result = generation_response_to_chat_completion(response=response, model="test_model")
+    assert result.choices[0].message.content == ""
+
+
+# ===== Tests for create_error_chat_completion =====
+
+
+def test_create_error_chat_completion():
+    """Test creating an error chat completion response."""
+    error_message = "This is an error message"
+    config_id = "test_config_id"
+    result = create_error_chat_completion(model="test_model", error_message=error_message, config_id=config_id)
+    assert result.choices[0].message.content == error_message
+    assert result.model == "test_model"
+    assert result.config_id == config_id
+    assert result.object == "chat.completion"
+    assert result.choices[0].message.role == "assistant"
+    assert result.choices[0].finish_reason == "stop"
+
+
+def test_create_error_chat_completion_without_config_id():
+    """Test creating an error chat completion without config_id."""
+    result = create_error_chat_completion(model="gpt-4", error_message="Error occurred")
+    assert result.choices[0].message.content == "Error occurred"
+    assert result.model == "gpt-4"
+    assert result.config_id is None
+
+
+# ===== Tests for format_streaming_chunk =====
+
+
+def test_format_streaming_chunk_with_dict():
+    """Test formatting a dict chunk."""
+    chunk = {"content": "Hello"}
+    result = format_streaming_chunk(chunk, model="test_model")
+    assert result["object"] == "chat.completion.chunk"
+    assert result["model"] == "test_model"
+    assert result["choices"][0]["delta"] == {"content": "Hello"}
+    assert result["choices"][0]["index"] == 0
+    assert result["choices"][0]["finish_reason"] is None
+    assert "id" in result
+    assert "created" in result
+
+
+def test_format_streaming_chunk_with_plain_string():
+    """Test formatting a plain string chunk."""
+    chunk = "Hello world"
+    result = format_streaming_chunk(chunk, model="test_model")
+    assert result["object"] == "chat.completion.chunk"
+    assert result["model"] == "test_model"
+    assert result["choices"][0]["delta"]["content"] == "Hello world"
+    assert result["choices"][0]["index"] == 0
+    assert result["choices"][0]["finish_reason"] is None
+
+
+def test_format_streaming_chunk_with_json_string():
+    """Test formatting a JSON string chunk."""
+    chunk_data = {"custom": "data", "value": 123}
+    chunk = json.dumps(chunk_data)
+    result = format_streaming_chunk(chunk, model="test_model", chunk_id="test-id")
+    assert result["id"] == "test-id"
+    assert result["model"] == "test_model"
+    # Should parse the JSON and add missing fields
+    assert result["custom"] == "data"
+    assert result["value"] == 123
+
+
+def test_format_streaming_chunk_with_none():
+    """Test formatting a None chunk."""
+    chunk = None
+    result = format_streaming_chunk(chunk, model="test_model")
+    assert result["choices"][0]["delta"]["content"] == "None"
+
+
+# ===== Tests for format_streaming_chunk_as_sse =====
+
+
+def test_format_streaming_chunk_as_sse_with_string():
+    """Test formatting a string chunk as SSE."""
+    chunk = "Hello SSE"
+    result = format_streaming_chunk_as_sse(chunk, model="test_model")
+
+    assert result.startswith("data: ")
+    assert result.endswith("\n\n")
+    json_str = result[6:-2]  # Remove "data: " and "\n\n"
+    payload = json.loads(json_str)
+    assert payload["object"] == "chat.completion.chunk"
+    assert payload["model"] == "test_model"
+    assert payload["choices"][0]["delta"]["content"] == "Hello SSE"
+
+
+def test_format_streaming_chunk_as_sse_with_dict():
+    """Test formatting a dict chunk as SSE."""
+    chunk = {"role": "assistant", "content": "SSE response"}
+    result = format_streaming_chunk_as_sse(chunk, model="test_model")
+    assert result.startswith("data: ")
+    assert result.endswith("\n\n")
+    json_str = result[6:-2]
+    payload = json.loads(json_str)
+    assert payload["choices"][0]["delta"] == {
+        "role": "assistant",
+        "content": "SSE response",
+    }
+
+
+def test_format_streaming_chunk_as_sse_with_none():
+    """Test formatting a None chunk as SSE."""
+    result = format_streaming_chunk_as_sse(None, model="test_model")
+    json_str = result[6:-2]
+    payload = json.loads(json_str)
+    assert payload["choices"][0]["delta"] == {
+        "content": "None",
+    }
+
+
+def test_format_streaming_chunk_as_sse_with_empty_string():
+    """Test formatting an empty string chunk as SSE."""
+    result = format_streaming_chunk_as_sse("", model="test_model")
+    json_str = result[6:-2]
+    payload = json.loads(json_str)
+    assert payload["choices"][0]["delta"] == {
+        "content": "",
+    }
diff --git a/tests/test_server_calls_with_state.py b/tests/test_server_calls_with_state.py
index 051096432..736f2592c 100644
--- a/tests/test_server_calls_with_state.py
+++ b/tests/test_server_calls_with_state.py
@@ -37,12 +37,14 @@ def _test_call(config_id):
     )
     assert response.status_code == 200
     res = response.json()
-    assert len(res["messages"]) == 1
-    assert res["messages"][0]["content"] == "Hello!"
+    assert len(res["choices"][0]["message"]) == 2
+    assert res["choices"][0]["message"]["content"] == "Hello!"
     assert res.get("state")
 
     # When making a second call with the returned state, the conversations should continue
     # and we should get the "Hello again!" message.
+    # For Colang 2.x, we only send the new user message, not the conversation history
+    # since the state maintains the conversation context.
     response = client.post(
         "/v1/chat/completions",
         json={
@@ -57,7 +59,7 @@ def _test_call(config_id):
         },
     )
     res = response.json()
-    assert res["messages"][0]["content"] == "Hello again!"
+    assert res["choices"][0]["message"]["content"] == "Hello again!"
 
 
 def test_1():
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index c7f59a7d1..fa7ffaa49 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -815,3 +815,62 @@
     assert rails.main_llm_supports_streaming is False, (
         "main_llm_supports_streaming should be False when streaming is disabled"
     )
+
+
+def test_main_llm_supports_streaming_with_multiple_model_types(
+    custom_streaming_providers,
+):
+    """Test that streaming is properly configured when config has multiple model types."""
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "main",
+                    "engine": "custom_streaming",
+                    "model": "test-model",
+                },
+                {
+                    "type": "content_safety",
+                    "engine": "custom_streaming",
+                    "model": "safety-model",
+                },
+            ],
+            "streaming": True,
+        }
+    )
+
+    rails = LLMRails(config)
+
+    assert rails.main_llm_supports_streaming is True, (
+        "main_llm_supports_streaming should be True when streaming is enabled "
+        "and config has multiple model types including a streaming-capable main LLM"
+    )
+    # Verify the main LLM's streaming attribute was set
+    assert hasattr(rails.llm, "streaming") and rails.llm.streaming is True, (
+        "Main LLM's streaming attribute should be set to True"
+    )
+
+
+def test_main_llm_supports_streaming_with_specialized_models_only(
+    custom_streaming_providers,
+):
+    """Test streaming config when only specialized models are defined (no main)."""
+    config = RailsConfig.from_content(
+        config={
+            "models": [
+                {
+                    "type": "content_safety",
+                    "engine": "custom_streaming",
+                    "model": "safety-model",
+                },
+            ],
+            "streaming": True,
+        }
+    )
+
+    rails = LLMRails(config)
+
+    # Verify that main_llm_supports_streaming is False when no main LLM is configured
+    assert rails.main_llm_supports_streaming is False, (
+        "main_llm_supports_streaming should be False when no main LLM is configured"
+    )
diff --git a/tests/test_threads.py b/tests/test_threads.py
index 88946007b..baace32b7 100644
--- a/tests/test_threads.py
+++ b/tests/test_threads.py
@@ -51,8 +51,9 @@ def test_1():
     )
    assert response.status_code == 200
     res = response.json()
-    assert len(res["messages"]) == 1
-    assert res["messages"][0]["content"] == "Hello!"
+    assert "choices" in res
+    assert "message" in res["choices"][0]
+    assert res["choices"][0]["message"]["content"] == "Hello!"
 
     # When making a second call with the same thread_id, the conversations should continue
     # and we should get the "Hello again!" message.
@@ -70,7 +71,7 @@
         },
     )
     res = response.json()
-    assert res["messages"][0]["content"] == "Hello again!"
+    assert res["choices"][0]["message"]["content"] == "Hello again!"
 
 
 @pytest.mark.parametrize(
@@ -138,4 +139,4 @@ def test_with_redis():
         },
     )
     res = response.json()
-    assert res["messages"][0]["content"] == "Hello again!"
+    assert res["choices"][0]["message"]["content"] == "Hello again!"