Commit 416ac39

Add OpenAI docs and integration tests
1 parent aaac161

7 files changed (+258 −75 lines)
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+## OpenAI API Compatibility for NeMo Guardrails
+
+NeMo Guardrails provides server-side compatibility with OpenAI API endpoints, enabling applications that use OpenAI clients to integrate seamlessly and add guardrails to their LLM interactions. Point your OpenAI client at `http://localhost:8000` (or your server URL) and use the standard `/v1/chat/completions` endpoint.
+
+## Feature Support Matrix
+
+The following table outlines which OpenAI API features are currently supported when using NeMo Guardrails:
+
+| Feature | Status | Notes |
+| :------ | :----: | :---- |
+| **Basic Chat Completion** | ✔ Supported | Full support for standard chat completions with guardrails applied |
+| **Streaming Responses** | ✔ Supported | Server-Sent Events (SSE) streaming with `stream=true` |
+| **Multimodal Input** | ✖ Unsupported | Text and image inputs (vision models) are supported with guardrails, but not yet through the OpenAI-compatible endpoint |
+| **Function Calling** | ✖ Unsupported | Not yet implemented; guardrails need structured output support |
+| **Tools** | ✖ Unsupported | Related to function calling; requires action flow integration |
+| **Response Format (JSON Mode)** | ✖ Unsupported | Structured output with guardrails requires additional validation logic |
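
For orientation, this is what the endpoint described in the new document accepts from a stock OpenAI client. A minimal sketch, assuming a guardrails server on `http://localhost:8000` and a configuration named `demo_config` (a placeholder name):

```python
# Minimal sketch: point the official OpenAI client at a running
# NeMo Guardrails server instead of api.openai.com.
# Assumptions: server on http://localhost:8000 and a deployed
# guardrails configuration named "demo_config" (placeholder).
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",  # guardrails server, not OpenAI
    api_key="not-needed",  # required by the client; assumed unchecked by the server
)

response = client.chat.completions.create(
    model="demo_config",  # maps to config_id on the server
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```

The `model` field doubling as the configuration id follows from the request-schema changes in `api.py` below.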

nemoguardrails/server/api.py

Lines changed: 60 additions & 59 deletions
@@ -28,7 +28,8 @@
 
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.model import Model
 from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
@@ -191,7 +192,7 @@ async def root_handler():
 app.single_config_id = None
 
 
-class RequestBody(ChatCompletion):
+class RequestBody(BaseModel):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -208,70 +209,67 @@ class RequestBody(ChatCompletion):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
-    model: Optional[str] = Field(
-        default=None,
-        description="The model used for the chat completion.",
+    messages: Optional[List[dict]] = Field(
+        default=None, description="The list of messages in the current conversation."
     )
-    id: Optional[str] = Field(
+    context: Optional[dict] = Field(
         default=None,
-        description="The id of the chat completion.",
+        description="Additional context data to be added to the conversation.",
+    )
+    stream: Optional[bool] = Field(
+        default=False,
+        description="If set, partial message deltas will be sent, like in ChatGPT. "
+        "Tokens will be sent as data-only server-sent events as they become "
+        "available, with the stream terminated by a data: [DONE] message.",
     )
-    object: Optional[str] = Field(
-        default="chat.completion",
-        description="The object type, which is always chat.completion",
+    options: GenerationOptions = Field(
+        default_factory=GenerationOptions,
+        description="Additional options for controlling the generation.",
     )
-    created: Optional[int] = Field(
+    state: Optional[dict] = Field(
         default=None,
-        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+        description="A state object that should be used to continue the interaction.",
     )
-    choices: Optional[List[Choice]] = Field(
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
         default=None,
-        description="The list of choices for the chat completion.",
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
     )
     max_tokens: Optional[int] = Field(
         default=None,
         description="The maximum number of tokens to generate.",
     )
     temperature: Optional[float] = Field(
         default=None,
-        description="The temperature to use for the chat completion.",
+        description="Sampling temperature to use.",
     )
     top_p: Optional[float] = Field(
         default=None,
-        description="The top p to use for the chat completion.",
+        description="Top-p sampling parameter.",
     )
-    stop: Optional[Union[str, List[str]]] = Field(
+    stop: Optional[str] = Field(
         default=None,
-        description="The stop sequences to use for the chat completion.",
+        description="Stop sequences.",
     )
     presence_penalty: Optional[float] = Field(
         default=None,
-        description="The presence penalty to use for the chat completion.",
+        description="Presence penalty parameter.",
     )
     frequency_penalty: Optional[float] = Field(
         default=None,
-        description="The frequency penalty to use for the chat completion.",
+        description="Frequency penalty parameter.",
     )
-    messages: Optional[List[dict]] = Field(
-        default=None, description="The list of messages in the current conversation."
-    )
-    context: Optional[dict] = Field(
+    function_call: Optional[dict] = Field(
         default=None,
-        description="Additional context data to be added to the conversation.",
+        description="Function call parameter.",
     )
-    stream: Optional[bool] = Field(
-        default=False,
-        description="If set, partial message deltas will be sent, like in ChatGPT. "
-        "Tokens will be sent as data-only server-sent events as they become "
-        "available, with the stream terminated by a data: [DONE] message.",
-    )
-    options: GenerationOptions = Field(
-        default_factory=GenerationOptions,
-        description="Additional options for controlling the generation.",
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
     )
-    state: Optional[dict] = Field(
+    log_probs: Optional[bool] = Field(
         default=None,
-        description="A state object that should be used to continue the interaction.",
+        description="Log probabilities parameter.",
     )
 
     @root_validator(pre=True)
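
The reshaped `RequestBody` keeps the guardrails-specific fields (`config_id`, `context`, `state`, `options`) and adds the standard OpenAI parameters alongside them, so a raw POST can mix both in one flat JSON object. An illustrative payload (every value is a placeholder; `httpx` is used only as an example HTTP client):

```python
# Illustrative raw request against the reshaped RequestBody schema.
# All values are placeholders; httpx is just an example client.
import httpx

payload = {
    "config_id": "demo_config",  # guardrails-specific field
    "messages": [{"role": "user", "content": "Hi!"}],
    "context": {"user_id": "u-123"},  # extra conversation context
    "stream": False,
    # standard OpenAI parameters, accepted side by side:
    "temperature": 0.2,
    "max_tokens": 256,
}

resp = httpx.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```
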
@@ -453,7 +451,7 @@ async def _format_streaming_response(
                 "choices": [
                     {
                         "delta": chunk,
-                        "index": None,
+                        "index": 0,
                         "finish_reason": None,
                     }
                 ],
@@ -472,7 +470,7 @@ async def _format_streaming_response(
                 "choices": [
                     {
                         "delta": {"content": chunk},
-                        "index": None,
+                        "index": 0,
                         "finish_reason": None,
                     }
                 ],
@@ -487,7 +485,7 @@ async def _format_streaming_response(
                 "choices": [
                     {
                         "delta": {"content": str(chunk)},
-                        "index": None,
+                        "index": 0,
                         "finish_reason": None,
                     }
                 ],
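
These three hunks fix the delta chunks to report `index: 0` instead of `null`, which is what OpenAI client libraries expect when parsing streamed choices. A sketch of consuming the stream, under the same placeholder server and configuration as above:

```python
# Sketch: stream deltas from the guardrails server with the OpenAI client.
# Same placeholder server (localhost:8000) and "demo_config" as above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="demo_config",
    messages=[{"role": "user", "content": "Write a haiku."}],
    stream=True,  # data-only server-sent events, terminated by [DONE]
)
for chunk in stream:
    delta = chunk.choices[0].delta  # the single choice now sits at index 0
    if delta.content:
        print(delta.content, end="", flush=True)
print()
```
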
@@ -536,16 +534,16 @@ async def chat_completion(body: RequestBody, request: Request):
             id=f"chatcmpl-{uuid.uuid4()}",
             object="chat.completion",
             created=int(time.time()),
-            model=config_ids[0] if config_ids else None,
+            model=config_ids[0] if config_ids else "unknown",
             choices=[
                 Choice(
                     index=0,
-                    message={
-                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                    message=ChatCompletionMessage(
+                        content=f"Could not load the {config_ids} guardrails configuration. "
                         f"An internal error has occurred.",
-                        "role": "assistant",
-                    },
-                    finish_reason="error",
+                        role="assistant",
+                    ),
+                    finish_reason="stop",
                     logprobs=None,
                 )
             ],
@@ -569,15 +567,15 @@ async def chat_completion(body: RequestBody, request: Request):
             id=f"chatcmpl-{uuid.uuid4()}",
             object="chat.completion",
             created=int(time.time()),
-            model=None,
+            model=config_ids[0] if config_ids else "unknown",
             choices=[
                 Choice(
                     index=0,
-                    message={
-                        "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        "role": "assistant",
-                    },
-                    finish_reason="error",
+                    message=ChatCompletionMessage(
+                        content="The `thread_id` must have a minimum length of 16 characters.",
+                        role="assistant",
+                    ),
+                    finish_reason="stop",
                     logprobs=None,
                 )
             ],
@@ -661,11 +659,14 @@ async def chat_completion(body: RequestBody, request: Request):
         "id": f"chatcmpl-{uuid.uuid4()}",
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": config_ids[0] if config_ids else None,
+        "model": config_ids[0] if config_ids else "unknown",
         "choices": [
             Choice(
                 index=0,
-                message=bot_message,
+                message=ChatCompletionMessage(
+                    role="assistant",
+                    content=bot_message["content"],
+                ),
                 finish_reason="stop",
                 logprobs=None,
             )
@@ -687,15 +688,15 @@ async def chat_completion(body: RequestBody, request: Request):
             id=f"chatcmpl-{uuid.uuid4()}",
             object="chat.completion",
             created=int(time.time()),
-            model=None,
+            model="unknown",
             choices=[
                 Choice(
                     index=0,
-                    message={
-                        "content": "Internal server error",
-                        "role": "assistant",
-                    },
-                    finish_reason="error",
+                    message=ChatCompletionMessage(
+                        content="Internal server error",
+                        role="assistant",
+                    ),
+                    finish_reason="stop",
                     logprobs=None,
                 )
             ],
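
The error paths now build typed `ChatCompletionMessage` objects and use `finish_reason="stop"`, since `"error"` is not among the finish reasons the OpenAI `Choice` type admits. One consequence, sketched below, is that failures come back as well-formed completions, so a client has to inspect the message text to notice them (same placeholder server; `missing_config` is a hypothetical id that fails to load):

```python
# Sketch: error replies are now schema-valid ChatCompletion objects,
# so failure detection relies on the message text, not finish_reason.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="missing_config",  # placeholder id that fails to load
    messages=[{"role": "user", "content": "Hi!"}],
)
message = response.choices[0].message
# finish_reason is "stop" even on failure; the error text is in content.
if "An internal error has occurred" in (message.content or ""):
    print("Guardrails configuration failed to load:", message.content)
```
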

nemoguardrails/server/schemas/openai.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 from typing import List, Optional
 
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.model import Model
 from pydantic import BaseModel, Field
 
poetry.lock

Lines changed: 10 additions & 10 deletions
(Generated file; diff not rendered.)

pyproject.toml

Lines changed: 3 additions & 4 deletions
@@ -71,13 +71,13 @@ starlette = ">=0.49.1"
 typer = ">=0.8"
 uvicorn = ">=0.23"
 watchdog = ">=3.0.0,"
+aiofiles = ">=24.1.0"
+openai = ">=1.0.0, <2.0.0"
 
 # tracing
 opentelemetry-api = { version = ">=1.27.0,<2.0.0", optional = true }
-aiofiles = { version = ">=24.1.0", optional = true }
 
 # openai
-openai = { version = ">=1.0.0, <2.0.0", optional = true }
 langchain-openai = { version = ">=0.1.0", optional = true }
 
 # eval
@@ -111,7 +111,7 @@ sdd = ["presidio-analyzer", "presidio-anonymizer"]
 eval = ["tqdm", "numpy", "streamlit", "tornado"]
 openai = ["langchain-openai"]
 gcp = ["google-cloud-language"]
-tracing = ["opentelemetry-api", "aiofiles"]
+tracing = ["opentelemetry-api"]
 nvidia = ["langchain-nvidia-ai-endpoints"]
 jailbreak = ["yara-python"]
 # Poetry does not support recursive dependencies, so we need to add all the dependencies here.
@@ -126,7 +126,6 @@ all = [
     "langchain-openai",
     "google-cloud-language",
     "opentelemetry-api",
-    "aiofiles",
     "langchain-nvidia-ai-endpoints",
     "yara-python",
 ]

tests/test_api.py

Lines changed: 1 addition & 1 deletion
@@ -480,7 +480,7 @@ async def collector():
 
         choice = j["choices"][0]
         assert "delta" in choice
-        assert choice["index"] is None
+        assert choice["index"] == 0
         assert choice["finish_reason"] is None
 
 
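
The updated assertion, restated standalone for clarity; the two raw events below are illustrative stand-ins for the SSE payloads the server emits:

```python
# Standalone restatement of the updated streaming assertion: every
# chunk's single choice reports index 0 (previously None) and no
# finish_reason until the stream ends. The raw events are illustrative.
import json

raw_events = [
    '{"choices": [{"delta": {"content": "Hel"}, "index": 0, "finish_reason": null}]}',
    '{"choices": [{"delta": {"content": "lo"}, "index": 0, "finish_reason": null}]}',
]

for event in raw_events:
    choice = json.loads(event)["choices"][0]
    assert "delta" in choice
    assert choice["index"] == 0  # was `is None` before this commit
    assert choice["finish_reason"] is None
```
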
