Skip to content

Commit 4cdf232

Browse files
chore: Move OpenAPI schema and fix typos
1 parent 9970296 commit 4cdf232

File tree

4 files changed

+168
-128
lines changed

4 files changed

+168
-128
lines changed

nemoguardrails/server/api.py

Lines changed: 18 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
1516
import asyncio
1617
import contextvars
1718
import importlib.util
@@ -27,16 +28,20 @@
2728

2829
from fastapi import FastAPI, Request
2930
from fastapi.middleware.cors import CORSMiddleware
30-
from pydantic import BaseModel, Field, root_validator, validator
31+
from pydantic import Field, root_validator, validator
3132
from starlette.responses import StreamingResponse
3233
from starlette.staticfiles import StaticFiles
3334

3435
from nemoguardrails import LLMRails, RailsConfig, utils
35-
from nemoguardrails.rails.llm.options import (
36-
GenerationOptions,
37-
GenerationResponse,
38-
)
36+
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
3937
from nemoguardrails.server.datastore.datastore import DataStore
38+
from nemoguardrails.server.schemas.openai import (
39+
Choice,
40+
Model,
41+
ModelsResponse,
42+
OpenAIRequestFields,
43+
ResponseBody,
44+
)
4045
from nemoguardrails.streaming import StreamingHandler
4146

4247
logging.basicConfig(level=logging.INFO)
@@ -190,7 +195,7 @@ async def root_handler():
190195
app.single_config_id = None
191196

192197

193-
class RequestBody(BaseModel):
198+
class RequestBody(OpenAIRequestFields):
194199
config_id: Optional[str] = Field(
195200
default=os.getenv("DEFAULT_CONFIG_ID", None),
196201
description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -228,47 +233,6 @@ class RequestBody(BaseModel):
228233
default=None,
229234
description="A state object that should be used to continue the interaction.",
230235
)
231-
# Standard OpenAI completion parameters
232-
model: Optional[str] = Field(
233-
default=None,
234-
description="The model to use for chat completion. Maps to config_id for backward compatibility.",
235-
)
236-
max_tokens: Optional[int] = Field(
237-
default=None,
238-
description="The maximum number of tokens to generate.",
239-
)
240-
temperature: Optional[float] = Field(
241-
default=None,
242-
description="Sampling temperature to use.",
243-
)
244-
top_p: Optional[float] = Field(
245-
default=None,
246-
description="Top-p sampling parameter.",
247-
)
248-
stop: Optional[str] = Field(
249-
default=None,
250-
description="Stop sequences.",
251-
)
252-
presence_penalty: Optional[float] = Field(
253-
default=None,
254-
description="Presence penalty parameter.",
255-
)
256-
frequency_penalty: Optional[float] = Field(
257-
default=None,
258-
description="Frequency penalty parameter.",
259-
)
260-
function_call: Optional[dict] = Field(
261-
default=None,
262-
description="Function call parameter.",
263-
)
264-
logit_bias: Optional[dict] = Field(
265-
default=None,
266-
description="Logit bias parameter.",
267-
)
268-
log_probs: Optional[bool] = Field(
269-
default=None,
270-
description="Log probabilities parameter.",
271-
)
272236

273237
@root_validator(pre=True)
274238
def ensure_config_id(cls, data: Any) -> Any:
@@ -295,75 +259,6 @@ def ensure_config_ids(cls, v, values):
295259
return v
296260

297261

298-
class Choice(BaseModel):
299-
index: Optional[int] = Field(
300-
default=None, description="The index of the choice in the list of choices."
301-
)
302-
messages: Optional[dict] = Field(
303-
default=None, description="The message of the choice"
304-
)
305-
logprobs: Optional[dict] = Field(
306-
default=None, description="The log probabilities of the choice"
307-
)
308-
finish_reason: Optional[str] = Field(
309-
default=None, description="The reason the model stopped generating tokens."
310-
)
311-
312-
313-
class ResponseBody(BaseModel):
314-
# OpenAI-compatible fields
315-
id: Optional[str] = Field(
316-
default=None, description="A unique identifier for the chat completion."
317-
)
318-
object: str = Field(
319-
default="chat.completion",
320-
description="The object type, which is always chat.completion",
321-
)
322-
created: Optional[int] = Field(
323-
default=None,
324-
description="The Unix timestamp (in seconds) of when the chat completion was created.",
325-
)
326-
model: Optional[str] = Field(
327-
default=None, description="The model used for the chat completion."
328-
)
329-
choices: Optional[List[Choice]] = Field(
330-
default=None, description="A list of chat completion choices."
331-
)
332-
# NeMo-Guardrails specific fields for backward compatibility
333-
state: Optional[dict] = Field(
334-
default=None, description="State object for continuing the conversation."
335-
)
336-
llm_output: Optional[dict] = Field(
337-
default=None, description="Additional LLM output data."
338-
)
339-
output_data: Optional[dict] = Field(
340-
default=None, description="Additional output data."
341-
)
342-
log: Optional[dict] = Field(default=None, description="Generation log data.")
343-
344-
345-
class Model(BaseModel):
346-
id: str = Field(
347-
description="The model identifier, which can be referenced in the API endpoints."
348-
)
349-
object: str = Field(
350-
default="model", description="The object type, which is always 'model'."
351-
)
352-
created: int = Field(
353-
description="The Unix timestamp (in seconds) of when the model was created."
354-
)
355-
owned_by: str = Field(
356-
default="nemo-guardrails", description="The organization that owns the model."
357-
)
358-
359-
360-
class ModelsResponse(BaseModel):
361-
object: str = Field(
362-
default="list", description="The object type, which is always 'list'."
363-
)
364-
data: List[Model] = Field(description="The list of models.")
365-
366-
367262
@app.get(
368263
"/v1/models",
369264
response_model=ModelsResponse,
@@ -538,7 +433,7 @@ async def chat_completion(body: RequestBody, request: Request):
538433
choices=[
539434
Choice(
540435
index=0,
541-
messages={
436+
message={
542437
"content": f"Could not load the {config_ids} guardrails configuration. "
543438
f"An internal error has occurred.",
544439
"role": "assistant",
@@ -571,7 +466,7 @@ async def chat_completion(body: RequestBody, request: Request):
571466
choices=[
572467
Choice(
573468
index=0,
574-
messages={
469+
message={
575470
"content": "The `thread_id` must have a minimum length of 16 characters.",
576471
"role": "assistant",
577472
},
@@ -589,11 +484,13 @@ async def chat_completion(body: RequestBody, request: Request):
589484
# And prepend them.
590485
messages = thread_messages + messages
591486

592-
# Map OpenAI-compatible parameters to generation options
593487
generation_options = body.options
488+
594489
# Initialize llm_params if not already set
595490
if generation_options.llm_params is None:
596491
generation_options.llm_params = {}
492+
493+
# Set OpenAI-compatible parameters in llm_params
597494
if body.max_tokens:
598495
generation_options.llm_params["max_tokens"] = body.max_tokens
599496
if body.temperature is not None:
@@ -656,7 +553,7 @@ async def chat_completion(body: RequestBody, request: Request):
656553
"choices": [
657554
Choice(
658555
index=0,
659-
messages=bot_message,
556+
message=bot_message,
660557
finish_reason="stop",
661558
logprobs=None,
662559
)
@@ -682,7 +579,7 @@ async def chat_completion(body: RequestBody, request: Request):
682579
choices=[
683580
Choice(
684581
index=0,
685-
messages={
582+
message={
686583
"content": "Internal server error",
687584
"role": "assistant",
688585
},
nemoguardrails/server/schemas/openai.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""OpenAI API schema definitions for the NeMo Guardrails server."""
17+
18+
from typing import List, Optional, Union
19+
20+
from pydantic import BaseModel, Field
21+
22+
23+
class OpenAIRequestFields(BaseModel):
    """OpenAI API request fields that can be mixed into other request schemas."""

    # Every field below is optional and falls back to None, so a request
    # that omits all of these parameters still validates.

    # Standard OpenAI completion parameters
    model: Optional[str] = Field(
        None,
        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
    )
    max_tokens: Optional[int] = Field(
        None, description="The maximum number of tokens to generate."
    )
    temperature: Optional[float] = Field(
        None, description="Sampling temperature to use."
    )
    top_p: Optional[float] = Field(None, description="Top-p sampling parameter.")
    # Accepts either a single stop string or a list of stop strings.
    stop: Union[str, List[str], None] = Field(None, description="Stop sequences.")
    presence_penalty: Optional[float] = Field(
        None, description="Presence penalty parameter."
    )
    frequency_penalty: Optional[float] = Field(
        None, description="Frequency penalty parameter."
    )
    function_call: Optional[dict] = Field(
        None, description="Function call parameter."
    )
    logit_bias: Optional[dict] = Field(None, description="Logit bias parameter.")
    log_probs: Optional[bool] = Field(
        None, description="Log probabilities parameter."
    )
67+
68+
69+
class Choice(BaseModel):
    """OpenAI API choice structure in chat completion responses."""

    # All four fields are optional and default to None.
    index: Optional[int] = Field(
        None,
        description="The index of the choice in the list of choices.",
    )
    message: Optional[dict] = Field(None, description="The message of the choice")
    logprobs: Optional[dict] = Field(
        None,
        description="The log probabilities of the choice",
    )
    finish_reason: Optional[str] = Field(
        None,
        description="The reason the model stopped generating tokens.",
    )
84+
85+
86+
class ResponseBody(BaseModel):
    """OpenAI API response body with NeMo-Guardrails extensions."""

    # OpenAI API fields
    id: Optional[str] = Field(
        None,
        description="A unique identifier for the chat completion.",
    )
    # The only field here with a non-None default.
    object: str = Field(
        "chat.completion",
        description="The object type, which is always chat.completion",
    )
    created: Optional[int] = Field(
        None,
        description="The Unix timestamp (in seconds) of when the chat completion was created.",
    )
    model: Optional[str] = Field(
        None,
        description="The model used for the chat completion.",
    )
    choices: Optional[List[Choice]] = Field(
        None,
        description="A list of chat completion choices.",
    )

    # NeMo-Guardrails specific fields for backward compatibility
    state: Optional[dict] = Field(
        None,
        description="State object for continuing the conversation.",
    )
    llm_output: Optional[dict] = Field(
        None,
        description="Additional LLM output data.",
    )
    output_data: Optional[dict] = Field(
        None,
        description="Additional output data.",
    )
    log: Optional[dict] = Field(None, description="Generation log data.")
118+
119+
120+
class Model(BaseModel):
    """OpenAI API model representation."""

    # `id` and `created` have no default and must be supplied by the caller;
    # `object` and `owned_by` fall back to fixed strings.
    id: str = Field(
        ...,
        description="The model identifier, which can be referenced in the API endpoints.",
    )
    object: str = Field(
        "model",
        description="The object type, which is always 'model'.",
    )
    created: int = Field(
        ...,
        description="The Unix timestamp (in seconds) of when the model was created.",
    )
    owned_by: str = Field(
        "nemo-guardrails",
        description="The organization that owns the model.",
    )
135+
136+
137+
class ModelsResponse(BaseModel):
    """OpenAI API models list response."""

    object: str = Field(
        "list",
        description="The object type, which is always 'list'.",
    )
    # Required: there is no default, so callers must always pass the list.
    data: List[Model] = Field(..., description="The list of models.")

tests/test_server_calls_with_state.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def _test_call(config_id):
3838
assert response.status_code == 200
3939
res = response.json()
4040
print(res)
41-
assert len(res["choices"][0]["messages"]) == 2
42-
assert res["choices"][0]["messages"]["content"] == "Hello!"
41+
assert len(res["choices"][0]["message"]) == 2
42+
assert res["choices"][0]["message"]["content"] == "Hello!"
4343
assert res.get("state")
4444

4545
# When making a second call with the returned state, the conversations should continue
@@ -60,7 +60,7 @@ def _test_call(config_id):
6060
},
6161
)
6262
res = response.json()
63-
assert res["choices"][0]["messages"]["content"] == "Hello again!"
63+
assert res["choices"][0]["message"]["content"] == "Hello again!"
6464

6565

6666
def test_1():

0 commit comments

Comments
 (0)