@@ -28,7 +28,8 @@
 
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.model import Model
 from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
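
The import change above previews the whole diff: `ChatCompletion` is no longer needed as a base class, and the SDK's typed `ChatCompletionMessage` replaces the plain `{"role": ..., "content": ...}` dicts in responses. A minimal sketch of the typed message, assuming openai-python v1.x:

```python
from openai.types.chat.chat_completion_message import ChatCompletionMessage

# Typed equivalent of {"role": "assistant", "content": "..."}.
msg = ChatCompletionMessage(role="assistant", content="Hello from the bot.")
print(msg.model_dump_json())  # serializes to the familiar OpenAI message shape
```
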
@@ -191,7 +192,7 @@ async def root_handler():
 app.single_config_id = None
 
 
-class RequestBody(ChatCompletion):
+class RequestBody(BaseModel):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -208,70 +209,67 @@ class RequestBody(ChatCompletion):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
-    model: Optional[str] = Field(
-        default=None,
-        description="The model used for the chat completion.",
+    messages: Optional[List[dict]] = Field(
+        default=None, description="The list of messages in the current conversation."
     )
-    id: Optional[str] = Field(
+    context: Optional[dict] = Field(
         default=None,
-        description="The id of the chat completion.",
+        description="Additional context data to be added to the conversation.",
+    )
+    stream: Optional[bool] = Field(
+        default=False,
+        description="If set, partial message deltas will be sent, like in ChatGPT. "
+        "Tokens will be sent as data-only server-sent events as they become "
+        "available, with the stream terminated by a data: [DONE] message.",
     )
-    object: Optional[str] = Field(
-        default="chat.completion",
-        description="The object type, which is always chat.completion",
+    options: GenerationOptions = Field(
+        default_factory=GenerationOptions,
+        description="Additional options for controlling the generation.",
     )
-    created: Optional[int] = Field(
+    state: Optional[dict] = Field(
         default=None,
-        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+        description="A state object that should be used to continue the interaction.",
     )
-    choices: Optional[List[Choice]] = Field(
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
         default=None,
-        description="The list of choices for the chat completion.",
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
     )
     max_tokens: Optional[int] = Field(
         default=None,
         description="The maximum number of tokens to generate.",
     )
     temperature: Optional[float] = Field(
         default=None,
-        description="The temperature to use for the chat completion.",
+        description="Sampling temperature to use.",
     )
     top_p: Optional[float] = Field(
         default=None,
-        description="The top p to use for the chat completion.",
+        description="Top-p sampling parameter.",
     )
-    stop: Optional[Union[str, List[str]]] = Field(
+    stop: Optional[str] = Field(
         default=None,
-        description="The stop sequences to use for the chat completion.",
+        description="Stop sequences.",
     )
     presence_penalty: Optional[float] = Field(
         default=None,
-        description="The presence penalty to use for the chat completion.",
+        description="Presence penalty parameter.",
     )
     frequency_penalty: Optional[float] = Field(
         default=None,
-        description="The frequency penalty to use for the chat completion.",
+        description="Frequency penalty parameter.",
     )
-    messages: Optional[List[dict]] = Field(
-        default=None, description="The list of messages in the current conversation."
-    )
-    context: Optional[dict] = Field(
+    function_call: Optional[dict] = Field(
         default=None,
-        description="Additional context data to be added to the conversation.",
+        description="Function call parameter.",
     )
-    stream: Optional[bool] = Field(
-        default=False,
-        description="If set, partial message deltas will be sent, like in ChatGPT. "
-        "Tokens will be sent as data-only server-sent events as they become "
-        "available, with the stream terminated by a data: [DONE] message.",
-    )
-    options: GenerationOptions = Field(
-        default_factory=GenerationOptions,
-        description="Additional options for controlling the generation."
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
     )
-    state: Optional[dict] = Field(
+    log_probs: Optional[bool] = Field(
         default=None,
-        description="A state object that should be used to continue the interaction.",
+        description="Log probabilities parameter.",
     )
 
     @root_validator(pre=True)
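
With `RequestBody` now a plain `BaseModel`, the request schema is decoupled from the OpenAI response type: the response-only fields (`id`, `object`, `created`, `choices`) are gone, and the standard completion parameters are accepted explicitly. A hypothetical client payload exercising the reworked model; the host, port, and endpoint path are assumptions:

```python
import requests  # assumption: the requests package is installed

# "model" maps to config_id for backward compatibility, per the field docs.
payload = {
    "model": "my_guardrails_config",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
    "temperature": 0.2,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```
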
@@ -453,7 +451,7 @@ async def _format_streaming_response(
453451 "choices" : [
454452 {
455453 "delta" : chunk ,
456- "index" : None ,
454+ "index" : 0 ,
457455 "finish_reason" : None ,
458456 }
459457 ],
@@ -472,7 +470,7 @@ async def _format_streaming_response(
472470 "choices" : [
473471 {
474472 "delta" : {"content" : chunk },
475- "index" : None ,
473+ "index" : 0 ,
476474 "finish_reason" : None ,
477475 }
478476 ],
@@ -487,7 +485,7 @@ async def _format_streaming_response(
487485 "choices" : [
488486 {
489487 "delta" : {"content" : str (chunk )},
490- "index" : None ,
488+ "index" : 0 ,
491489 "finish_reason" : None ,
492490 }
493491 ],
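
All three streaming branches now emit `"index": 0` instead of `None`, which matters because OpenAI client libraries group deltas by choice index when reassembling a streamed message. A rough sketch of the data-only SSE framing these chunks travel in (the envelope fields around `choices` are assumptions based on the OpenAI chunk format, not shown in this hunk):

```python
import json
import time
import uuid

def sse_line(delta: dict) -> str:
    # One data-only server-sent event carrying a chat.completion.chunk.
    event = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "choices": [{"delta": delta, "index": 0, "finish_reason": None}],
    }
    return f"data: {json.dumps(event)}\n\n"

print(sse_line({"content": "Hel"}), end="")
print(sse_line({"content": "lo"}), end="")
print("data: [DONE]\n\n", end="")  # terminates the stream
```
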
@@ -536,16 +534,16 @@ async def chat_completion(body: RequestBody, request: Request):
536534 id = f"chatcmpl-{ uuid .uuid4 ()} " ,
537535 object = "chat.completion" ,
538536 created = int (time .time ()),
539- model = config_ids [0 ] if config_ids else None ,
537+ model = config_ids [0 ] if config_ids else "unknown" ,
540538 choices = [
541539 Choice (
542540 index = 0 ,
543- message = {
544- " content" : f"Could not load the { config_ids } guardrails configuration. "
541+ message = ChatCompletionMessage (
542+ content = f"Could not load the { config_ids } guardrails configuration. "
545543 f"An internal error has occurred." ,
546- " role" : "assistant" ,
547- } ,
548- finish_reason = "error " ,
544+ role = "assistant" ,
545+ ) ,
546+ finish_reason = "stop " ,
549547 logprobs = None ,
550548 )
551549 ],
@@ -569,15 +567,15 @@ async def chat_completion(body: RequestBody, request: Request):
         id=f"chatcmpl-{uuid.uuid4()}",
         object="chat.completion",
         created=int(time.time()),
-        model=None,
+        model=config_ids[0] if config_ids else "unknown",
         choices=[
             Choice(
                 index=0,
-                message={
-                    "content": "The `thread_id` must have a minimum length of 16 characters.",
-                    "role": "assistant",
-                },
-                finish_reason="error",
+                message=ChatCompletionMessage(
+                    content="The `thread_id` must have a minimum length of 16 characters.",
+                    role="assistant",
+                ),
+                finish_reason="stop",
                 logprobs=None,
             )
         ],
@@ -661,11 +659,14 @@ async def chat_completion(body: RequestBody, request: Request):
661659 "id" : f"chatcmpl-{ uuid .uuid4 ()} " ,
662660 "object" : "chat.completion" ,
663661 "created" : int (time .time ()),
664- "model" : config_ids [0 ] if config_ids else None ,
662+ "model" : config_ids [0 ] if config_ids else "unknown" ,
665663 "choices" : [
666664 Choice (
667665 index = 0 ,
668- message = bot_message ,
666+ message = ChatCompletionMessage (
667+ role = "assistant" ,
668+ content = bot_message ["content" ],
669+ ),
669670 finish_reason = "stop" ,
670671 logprobs = None ,
671672 )
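
On the success path, `bot_message` (a plain dict) is now wrapped in the typed `ChatCompletionMessage` before being handed to `Choice`. A small sketch of the resulting object, with an illustrative `bot_message`:

```python
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage

bot_message = {"role": "assistant", "content": "Hello there!"}  # illustrative

choice = Choice(
    index=0,
    message=ChatCompletionMessage(role="assistant", content=bot_message["content"]),
    finish_reason="stop",
    logprobs=None,
)
print(choice.model_dump())  # OpenAI-shaped dict with the message nested inside
```

Note that the new code assumes `bot_message` always carries a `content` key; a defensive `bot_message.get("content", "")` might be safer if other message shapes can reach this branch.
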
@@ -687,15 +688,15 @@ async def chat_completion(body: RequestBody, request: Request):
         id=f"chatcmpl-{uuid.uuid4()}",
         object="chat.completion",
         created=int(time.time()),
-        model=None,
+        model="unknown",
         choices=[
             Choice(
                 index=0,
-                message={
-                    "content": "Internal server error",
-                    "role": "assistant",
-                },
-                finish_reason="error",
+                message=ChatCompletionMessage(
+                    content="Internal server error",
+                    role="assistant",
+                ),
+                finish_reason="stop",
                 logprobs=None,
             )
         ],
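
Every error branch also switches `finish_reason` from `"error"` to `"stop"`. This is more than cosmetic: in the v1 SDK types, `Choice.finish_reason` is a `Literal` that, to my knowledge, does not include `"error"`, so pydantic would reject the old value outright. A sketch:

```python
import pydantic
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage

msg = ChatCompletionMessage(role="assistant", content="Internal server error")

try:
    Choice(index=0, message=msg, finish_reason="error", logprobs=None)
except pydantic.ValidationError as exc:
    print(exc)  # "error" is not among the permitted literal values

ok = Choice(index=0, message=msg, finish_reason="stop", logprobs=None)
```
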