 import uuid
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, Callable, List, Optional
+from typing import Any, AsyncIterator, Callable, List, Optional, Union

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import Field, root_validator, validator
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
+from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles

 from nemoguardrails import LLMRails, RailsConfig, utils
 from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
-from nemoguardrails.server.schemas.openai import (
-    Choice,
-    Model,
-    ModelsResponse,
-    OpenAIRequestFields,
-    ResponseBody,
-)
+from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody
 from nemoguardrails.streaming import StreamingHandler

 logging.basicConfig(level=logging.INFO)
@@ -195,7 +191,7 @@ async def root_handler():
 app.single_config_id = None


-class RequestBody(OpenAIRequestFields):
+class RequestBody(ChatCompletion):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -212,6 +208,50 @@ class RequestBody(OpenAIRequestFields):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
+    model: Optional[str] = Field(
+        default=None,
+        description="The model used for the chat completion.",
+    )
+    id: Optional[str] = Field(
+        default=None,
+        description="The id of the chat completion.",
+    )
+    object: Optional[str] = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
+    )
+    created: Optional[int] = Field(
+        default=None,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    choices: Optional[List[Choice]] = Field(
+        default=None,
+        description="The list of choices for the chat completion.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="The temperature to use for the chat completion.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="The top p to use for the chat completion.",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        default=None,
+        description="The stop sequences to use for the chat completion.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="The presence penalty to use for the chat completion.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="The frequency penalty to use for the chat completion.",
+    )
     messages: Optional[List[dict]] = Field(
         default=None, description="The list of messages in the current conversation."
     )
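
With these fields in place, the request body mirrors the OpenAI chat completions schema. A minimal sketch of a request against the endpoint added further below, assuming the server runs at http://localhost:8000 and that a configuration named "my_config" exists (both are assumptions, not taken from this diff):

# Sketch only: exercising the extended RequestBody over HTTP.
# The host/port and the config id "my_config" are assumptions.
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "config_id": "my_config",
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.2,
        "max_tokens": 256,
    },
    timeout=60,
)
print(response.json())
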
@@ -391,6 +431,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
     return llm_rails


+async def _format_streaming_response(
+    streaming_handler: StreamingHandler, model_name: Optional[str]
+) -> AsyncIterator[str]:
+    while True:
+        try:
+            chunk = await streaming_handler.__anext__()
+        except StopAsyncIteration:
+            # When the stream ends, yield the [DONE] message
+            yield "data: [DONE]\n\n"
+            break
+
+        # Determine the payload format based on chunk type
+        if isinstance(chunk, dict):
+            # If chunk is a dict, wrap it in OpenAI chunk format with delta
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": chunk,
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+        elif isinstance(chunk, str):
+            try:
+                # Try parsing as JSON - if it parses, it might be a pre-formed payload
+                payload = json.loads(chunk)
+            except Exception:
+                # treat as plain text content token
+                payload = {
+                    "id": None,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "delta": {"content": chunk},
+                            "index": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+        else:
+            # For any other type, treat as plain content
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": {"content": str(chunk)},
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+
+        # Send the payload as JSON
+        data = json.dumps(payload, ensure_ascii=False)
+        yield f"data: {data}\n\n"
+
+
 @app.post(
     "/v1/chat/completions",
     response_model=ResponseBody,
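
For reference, a minimal sketch of the SSE line the formatter above would emit for a plain-text token. The values mirror the plain-text branch of _format_streaming_response; the model name "my_config" is a hypothetical config id passed as model_name:

# Sketch only: reproduces the payload built in the plain-text branch above.
import json
import time

token = "Hello"
payload = {
    "id": None,
    "object": "chat.completion.chunk",
    "created": int(time.time()),
    "model": "my_config",  # hypothetical config id passed as model_name
    "choices": [{"delta": {"content": token}, "index": None, "finish_reason": None}],
}
print(f"data: {json.dumps(payload, ensure_ascii=False)}")
# At end of stream, the generator yields a final line: "data: [DONE]"
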
@@ -522,7 +629,12 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )

-            return StreamingResponse(streaming_handler)
+            return StreamingResponse(
+                _format_streaming_response(
+                    streaming_handler, model_name=config_ids[0] if config_ids else None
+                ),
+                media_type="text/event-stream",
+            )
         else:
             res = await llm_rails.generate_async(
                 messages=messages, options=generation_options, state=body.state
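
A minimal client-side sketch for consuming the new SSE stream, assuming the server is reachable at http://localhost:8000 and that the request schema accepts a "stream" flag as in the OpenAI API (neither detail is shown in this diff):

# Sketch only: reads "data: ..." lines from the streaming endpoint.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "config_id": "my_config",  # hypothetical config id
        "messages": [{"role": "user", "content": "Tell me a short joke."}],
        "stream": True,  # assumed flag, mirroring the OpenAI API
    },
    stream=True,
    timeout=60,
) as response:
    for line in response.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
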