 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import asyncio
 import contextvars
 import importlib.util
@@ -27,16 +28,20 @@
 
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field, root_validator, validator
+from pydantic import Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles
 
 from nemoguardrails import LLMRails, RailsConfig, utils
-from nemoguardrails.rails.llm.options import (
-    GenerationOptions,
-    GenerationResponse,
-)
+from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
+from nemoguardrails.server.schemas.openai import (
+    Choice,
+    Model,
+    ModelsResponse,
+    OpenAIRequestFields,
+    ResponseBody,
+)
 from nemoguardrails.streaming import StreamingHandler
 
 logging.basicConfig(level=logging.INFO)
@@ -190,7 +195,7 @@ async def root_handler():
 app.single_config_id = None
 
 
-class RequestBody(BaseModel):
+class RequestBody(OpenAIRequestFields):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -228,47 +233,6 @@ class RequestBody(BaseModel):
         default=None,
         description="A state object that should be used to continue the interaction.",
     )
-    # Standard OpenAI completion parameters
-    model: Optional[str] = Field(
-        default=None,
-        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
-    )
-    max_tokens: Optional[int] = Field(
-        default=None,
-        description="The maximum number of tokens to generate.",
-    )
-    temperature: Optional[float] = Field(
-        default=None,
-        description="Sampling temperature to use.",
-    )
-    top_p: Optional[float] = Field(
-        default=None,
-        description="Top-p sampling parameter.",
-    )
-    stop: Optional[str] = Field(
-        default=None,
-        description="Stop sequences.",
-    )
-    presence_penalty: Optional[float] = Field(
-        default=None,
-        description="Presence penalty parameter.",
-    )
-    frequency_penalty: Optional[float] = Field(
-        default=None,
-        description="Frequency penalty parameter.",
-    )
-    function_call: Optional[dict] = Field(
-        default=None,
-        description="Function call parameter.",
-    )
-    logit_bias: Optional[dict] = Field(
-        default=None,
-        description="Logit bias parameter.",
-    )
-    log_probs: Optional[bool] = Field(
-        default=None,
-        description="Log probabilities parameter.",
-    )
 
     @root_validator(pre=True)
     def ensure_config_id(cls, data: Any) -> Any:
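
The ten OpenAI parameters removed above now live in the new `nemoguardrails.server.schemas.openai` module imported at the top of this diff. That module is not shown here; reconstructed from the deleted fields, `OpenAIRequestFields` presumably looks roughly like the sketch below (the real definition may differ):

```python
# Hypothetical sketch of OpenAIRequestFields, rebuilt from the fields this
# diff removes from RequestBody; not the actual module contents.
from typing import Optional

from pydantic import BaseModel, Field


class OpenAIRequestFields(BaseModel):
    """Standard OpenAI chat-completion parameters, shared via inheritance."""

    model: Optional[str] = Field(
        default=None,
        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
    )
    max_tokens: Optional[int] = Field(
        default=None, description="The maximum number of tokens to generate."
    )
    temperature: Optional[float] = Field(
        default=None, description="Sampling temperature to use."
    )
    top_p: Optional[float] = Field(
        default=None, description="Top-p sampling parameter."
    )
    stop: Optional[str] = Field(default=None, description="Stop sequences.")
    presence_penalty: Optional[float] = Field(
        default=None, description="Presence penalty parameter."
    )
    frequency_penalty: Optional[float] = Field(
        default=None, description="Frequency penalty parameter."
    )
    function_call: Optional[dict] = Field(
        default=None, description="Function call parameter."
    )
    logit_bias: Optional[dict] = Field(
        default=None, description="Logit bias parameter."
    )
    log_probs: Optional[bool] = Field(
        default=None, description="Log probabilities parameter."
    )
```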
@@ -295,75 +259,6 @@ def ensure_config_ids(cls, v, values):
         return v
 
 
-class Choice(BaseModel):
-    index: Optional[int] = Field(
-        default=None, description="The index of the choice in the list of choices."
-    )
-    messages: Optional[dict] = Field(
-        default=None, description="The message of the choice"
-    )
-    logprobs: Optional[dict] = Field(
-        default=None, description="The log probabilities of the choice"
-    )
-    finish_reason: Optional[str] = Field(
-        default=None, description="The reason the model stopped generating tokens."
-    )
-
-
-class ResponseBody(BaseModel):
-    # OpenAI-compatible fields
-    id: Optional[str] = Field(
-        default=None, description="A unique identifier for the chat completion."
-    )
-    object: str = Field(
-        default="chat.completion",
-        description="The object type, which is always chat.completion",
-    )
-    created: Optional[int] = Field(
-        default=None,
-        description="The Unix timestamp (in seconds) of when the chat completion was created.",
-    )
-    model: Optional[str] = Field(
-        default=None, description="The model used for the chat completion."
-    )
-    choices: Optional[List[Choice]] = Field(
-        default=None, description="A list of chat completion choices."
-    )
-    # NeMo-Guardrails specific fields for backward compatibility
-    state: Optional[dict] = Field(
-        default=None, description="State object for continuing the conversation."
-    )
-    llm_output: Optional[dict] = Field(
-        default=None, description="Additional LLM output data."
-    )
-    output_data: Optional[dict] = Field(
-        default=None, description="Additional output data."
-    )
-    log: Optional[dict] = Field(default=None, description="Generation log data.")
-
-
-class Model(BaseModel):
-    id: str = Field(
-        description="The model identifier, which can be referenced in the API endpoints."
-    )
-    object: str = Field(
-        default="model", description="The object type, which is always 'model'."
-    )
-    created: int = Field(
-        description="The Unix timestamp (in seconds) of when the model was created."
-    )
-    owned_by: str = Field(
-        default="nemo-guardrails", description="The organization that owns the model."
-    )
-
-
-class ModelsResponse(BaseModel):
-    object: str = Field(
-        default="list", description="The object type, which is always 'list'."
-    )
-    data: List[Model] = Field(description="The list of models.")
-
-
 @app.get(
     "/v1/models",
     response_model=ModelsResponse,
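
The `Choice`, `ResponseBody`, `Model`, and `ModelsResponse` classes deleted here move to the same schemas module. Note that the call sites updated in the hunks below switch from `messages=` to `message=`, so the relocated `Choice` presumably renames its field to `message`, matching the OpenAI chat-completion schema. A sketch under that assumption:

```python
# Hypothetical sketch of the relocated Choice model; the field rename is
# inferred from the messages= -> message= call-site changes below.
from typing import Optional

from pydantic import BaseModel, Field


class Choice(BaseModel):
    index: Optional[int] = Field(
        default=None, description="The index of the choice in the list of choices."
    )
    message: Optional[dict] = Field(
        default=None, description="The message of the choice."
    )
    logprobs: Optional[dict] = Field(
        default=None, description="The log probabilities of the choice."
    )
    finish_reason: Optional[str] = Field(
        default=None, description="The reason the model stopped generating tokens."
    )
```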
@@ -538,7 +433,7 @@ async def chat_completion(body: RequestBody, request: Request):
             choices=[
                 Choice(
                     index=0,
-                    messages={
+                    message={
                         "content": f"Could not load the {config_ids} guardrails configuration. "
                         f"An internal error has occurred.",
                         "role": "assistant",
@@ -571,7 +466,7 @@
             choices=[
                 Choice(
                     index=0,
-                    messages={
+                    message={
                         "content": "The `thread_id` must have a minimum length of 16 characters.",
                         "role": "assistant",
                     },
@@ -589,11 +484,13 @@
         # And prepend them.
         messages = thread_messages + messages
 
-    # Map OpenAI-compatible parameters to generation options
     generation_options = body.options
+
     # Initialize llm_params if not already set
     if generation_options.llm_params is None:
         generation_options.llm_params = {}
+
+    # Set OpenAI-compatible parameters in llm_params
     if body.max_tokens:
         generation_options.llm_params["max_tokens"] = body.max_tokens
     if body.temperature is not None:
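
With this mapping in place, a client can send the standard OpenAI sampling parameters in the request body and the server forwards them to the LLM via `GenerationOptions.llm_params`. A minimal client-side sketch; the server address, route, and config id are illustrative assumptions:

```python
# Hedged example: exercise the OpenAI-compatible parameters against a
# locally running nemoguardrails server (URL and config_id are assumed).
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "config_id": "my_config",  # hypothetical guardrails configuration
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 128,  # forwarded as llm_params["max_tokens"]
        "temperature": 0.2,  # forwarded as llm_params["temperature"]
    },
    timeout=60,
)
# With the renamed Choice field, the reply content sits under "message".
print(response.json()["choices"][0]["message"]["content"])
```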
@@ -656,7 +553,7 @@
                 "choices": [
                     Choice(
                         index=0,
-                        messages=bot_message,
+                        message=bot_message,
                         finish_reason="stop",
                         logprobs=None,
                     )
@@ -682,7 +579,7 @@
         choices=[
             Choice(
                 index=0,
-                messages={
+                message={
                     "content": "Internal server error",
                     "role": "assistant",
                 },
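
For completeness, the `/v1/models` endpoint shown earlier returns a `ModelsResponse`, so the relocated schemas can be exercised end to end (again assuming a local server on port 8000):

```python
# Hedged example: list the models exposed by the guardrails server.
import requests

resp = requests.get("http://localhost:8000/v1/models", timeout=10)
for model in resp.json()["data"]:
    # Each entry follows the Model schema: id, object, created, owned_by.
    print(model["id"], model["owned_by"])
```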