
Commit aaac161

Extend existing OpenAI types and add support for streaming chat completion
1 parent 4cdf232 commit aaac161


7 files changed: +484 −119 lines


nemoguardrails/server/api.py

Lines changed: 123 additions & 11 deletions
@@ -24,24 +24,20 @@
 import uuid
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, Callable, List, Optional
+from typing import Any, AsyncIterator, Callable, List, Optional, Union

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import Field, root_validator, validator
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
+from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles

 from nemoguardrails import LLMRails, RailsConfig, utils
 from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
-from nemoguardrails.server.schemas.openai import (
-    Choice,
-    Model,
-    ModelsResponse,
-    OpenAIRequestFields,
-    ResponseBody,
-)
+from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody
 from nemoguardrails.streaming import StreamingHandler

 logging.basicConfig(level=logging.INFO)
@@ -195,7 +191,7 @@ async def root_handler():
 app.single_config_id = None


-class RequestBody(OpenAIRequestFields):
+class RequestBody(ChatCompletion):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -212,6 +208,50 @@ class RequestBody(OpenAIRequestFields):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
+    model: Optional[str] = Field(
+        default=None,
+        description="The model used for the chat completion.",
+    )
+    id: Optional[str] = Field(
+        default=None,
+        description="The id of the chat completion.",
+    )
+    object: Optional[str] = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
+    )
+    created: Optional[int] = Field(
+        default=None,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    choices: Optional[List[Choice]] = Field(
+        default=None,
+        description="The list of choices for the chat completion.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="The temperature to use for the chat completion.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="The top p to use for the chat completion.",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        default=None,
+        description="The stop sequences to use for the chat completion.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="The presence penalty to use for the chat completion.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="The frequency penalty to use for the chat completion.",
+    )
     messages: Optional[List[dict]] = Field(
         default=None, description="The list of messages in the current conversation."
     )
@@ -391,6 +431,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
     return llm_rails


+async def _format_streaming_response(
+    streaming_handler: StreamingHandler, model_name: Optional[str]
+) -> AsyncIterator[str]:
+    while True:
+        try:
+            chunk = await streaming_handler.__anext__()
+        except StopAsyncIteration:
+            # When the stream ends, yield the [DONE] message
+            yield "data: [DONE]\n\n"
+            break
+
+        # Determine the payload format based on chunk type
+        if isinstance(chunk, dict):
+            # If chunk is a dict, wrap it in OpenAI chunk format with delta
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": chunk,
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+        elif isinstance(chunk, str):
+            try:
+                # Try parsing as JSON - if it parses, it might be a pre-formed payload
+                payload = json.loads(chunk)
+            except Exception:
+                # treat as plain text content token
+                payload = {
+                    "id": None,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "delta": {"content": chunk},
+                            "index": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+        else:
+            # For any other type, treat as plain content
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": {"content": str(chunk)},
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+
+        # Send the payload as JSON
+        data = json.dumps(payload, ensure_ascii=False)
+        yield f"data: {data}\n\n"
+
+
 @app.post(
     "/v1/chat/completions",
     response_model=ResponseBody,
@@ -522,7 +629,12 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )

-            return StreamingResponse(streaming_handler)
+            return StreamingResponse(
+                _format_streaming_response(
+                    streaming_handler, model_name=config_ids[0] if config_ids else None
+                ),
+                media_type="text/event-stream",
+            )
         else:
             res = await llm_rails.generate_async(
                 messages=messages, options=generation_options, state=body.state
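
Note: the streamed response follows the OpenAI server-sent-events convention. Each event is a "data: {...}" line carrying a chat.completion.chunk payload with a delta, and the stream is terminated by "data: [DONE]". A minimal client sketch is below; the server address, the requests dependency, the config id, and setting the stream flag on the request body are assumptions for illustration, not part of this commit.

# Sketch only: consumes the SSE stream produced by /v1/chat/completions.
# Assumptions: the guardrails server runs on localhost:8000, the `requests`
# package is installed, "my_config" is a hypothetical config id, and the
# request body supports a `stream` flag.
import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "config_id": "my_config",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    data = line[len("data: "):]
    if data == "[DONE]":
        break
    chunk = json.loads(data)
    # Each chunk carries a delta: {"content": "..."} for plain text tokens,
    # or a richer dict passed through from the guardrails runtime.
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)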

nemoguardrails/server/schemas/openai.py

Lines changed: 4 additions & 101 deletions
@@ -15,96 +15,16 @@

 """OpenAI API schema definitions for the NeMo Guardrails server."""

-from typing import List, Optional, Union
+from typing import List, Optional

+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
 from pydantic import BaseModel, Field


-class OpenAIRequestFields(BaseModel):
-    """OpenAI API request fields that can be mixed into other request schemas."""
-
-    # Standard OpenAI completion parameters
-    model: Optional[str] = Field(
-        default=None,
-        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
-    )
-    max_tokens: Optional[int] = Field(
-        default=None,
-        description="The maximum number of tokens to generate.",
-    )
-    temperature: Optional[float] = Field(
-        default=None,
-        description="Sampling temperature to use.",
-    )
-    top_p: Optional[float] = Field(
-        default=None,
-        description="Top-p sampling parameter.",
-    )
-    stop: Optional[Union[str, List[str]]] = Field(
-        default=None,
-        description="Stop sequences.",
-    )
-    presence_penalty: Optional[float] = Field(
-        default=None,
-        description="Presence penalty parameter.",
-    )
-    frequency_penalty: Optional[float] = Field(
-        default=None,
-        description="Frequency penalty parameter.",
-    )
-    function_call: Optional[dict] = Field(
-        default=None,
-        description="Function call parameter.",
-    )
-    logit_bias: Optional[dict] = Field(
-        default=None,
-        description="Logit bias parameter.",
-    )
-    log_probs: Optional[bool] = Field(
-        default=None,
-        description="Log probabilities parameter.",
-    )
-
-
-class Choice(BaseModel):
-    """OpenAI API choice structure in chat completion responses."""
-
-    index: Optional[int] = Field(
-        default=None, description="The index of the choice in the list of choices."
-    )
-    message: Optional[dict] = Field(
-        default=None, description="The message of the choice"
-    )
-    logprobs: Optional[dict] = Field(
-        default=None, description="The log probabilities of the choice"
-    )
-    finish_reason: Optional[str] = Field(
-        default=None, description="The reason the model stopped generating tokens."
-    )
-
-
-class ResponseBody(BaseModel):
+class ResponseBody(ChatCompletion):
     """OpenAI API response body with NeMo-Guardrails extensions."""

-    # OpenAI API fields
-    id: Optional[str] = Field(
-        default=None, description="A unique identifier for the chat completion."
-    )
-    object: str = Field(
-        default="chat.completion",
-        description="The object type, which is always chat.completion",
-    )
-    created: Optional[int] = Field(
-        default=None,
-        description="The Unix timestamp (in seconds) of when the chat completion was created.",
-    )
-    model: Optional[str] = Field(
-        default=None, description="The model used for the chat completion."
-    )
-    choices: Optional[List[Choice]] = Field(
-        default=None, description="A list of chat completion choices."
-    )
-    # NeMo-Guardrails specific fields for backward compatibility
     state: Optional[dict] = Field(
         default=None, description="State object for continuing the conversation."
     )
@@ -117,23 +37,6 @@ class ResponseBody(BaseModel):
     log: Optional[dict] = Field(default=None, description="Generation log data.")


-class Model(BaseModel):
-    """OpenAI API model representation."""
-
-    id: str = Field(
-        description="The model identifier, which can be referenced in the API endpoints."
-    )
-    object: str = Field(
-        default="model", description="The object type, which is always 'model'."
-    )
-    created: int = Field(
-        description="The Unix timestamp (in seconds) of when the model was created."
-    )
-    owned_by: str = Field(
-        default="nemo-guardrails", description="The organization that owns the model."
-    )
-
-
 class ModelsResponse(BaseModel):
     """OpenAI API models list response."""
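
Note: because ResponseBody now subclasses the openai package's ChatCompletion type, the standard fields (id, object, created, model, choices) are inherited and only the NeMo-Guardrails extensions (state, log, and so on) remain declared here. A rough construction sketch follows, assuming the openai>=1.x pydantic types; all values are placeholders.

# Sketch only: shows which fields come from ChatCompletion and which are
# NeMo-Guardrails extensions. Values are placeholders for illustration.
import time

from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from nemoguardrails.server.schemas.openai import ResponseBody

response = ResponseBody(
    # Inherited from openai.types.chat.chat_completion.ChatCompletion
    id="chatcmpl-123",
    object="chat.completion",
    created=int(time.time()),
    model="my_config",
    choices=[
        Choice(
            index=0,
            finish_reason="stop",
            message=ChatCompletionMessage(role="assistant", content="Hello!"),
        )
    ],
    # NeMo-Guardrails extension fields
    state={"thread": "..."},
    log=None,
)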

nemoguardrails/streaming.py

Lines changed: 51 additions & 4 deletions
@@ -173,18 +173,39 @@ async def __anext__(self):

     async def _process(
         self,
-        chunk: Union[str, object],
+        chunk: Union[str, dict, object],
         generation_info: Optional[Dict[str, Any]] = None,
     ):
-        """Process a chunk of text.
+        """Process a chunk of text or dict.

         If we're in buffering mode, record the text.
         Otherwise, update the full completion, check for stop tokens, and enqueue the chunk.
+        Dict chunks bypass completion tracking and go directly to the queue.
         """

         if self.include_generation_metadata and generation_info:
             self.current_generation_info = generation_info

+        # Dict chunks bypass buffering and completion tracking
+        if isinstance(chunk, dict):
+            if self.pipe_to:
+                asyncio.create_task(self.pipe_to.push_chunk(chunk))
+            else:
+                if self.include_generation_metadata:
+                    await self.queue.put(
+                        {
+                            "text": chunk,
+                            "generation_info": (
+                                self.current_generation_info.copy()
+                                if self.current_generation_info
+                                else {}
+                            ),
+                        }
+                    )
+                else:
+                    await self.queue.put(chunk)
+            return
+
         if self.enable_buffer:
             if chunk is not END_OF_STREAM:
                 self.buffer += chunk if chunk is not None else ""
254275

255276
async def push_chunk(
256277
self,
257-
chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None],
278+
chunk: Union[
279+
str,
280+
dict,
281+
GenerationChunk,
282+
AIMessageChunk,
283+
ChatGenerationChunk,
284+
None,
285+
object,
286+
],
258287
generation_info: Optional[Dict[str, Any]] = None,
259288
):
260-
"""Push a new chunk to the stream."""
289+
"""Push a new chunk to the stream.
290+
291+
Args:
292+
chunk: The chunk to push. Can be:
293+
- str: Plain text content
294+
- dict: Dictionary with fields like role, content, etc.
295+
- GenerationChunk/AIMessageChunk/ChatGenerationChunk: LangChain chunk types
296+
- None: Signals end of stream (converted to END_OF_STREAM)
297+
- object: END_OF_STREAM sentinel
298+
generation_info: Optional metadata about the generation
299+
"""
261300

262301
# if generation_info is not explicitly passed,
263302
# try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk
@@ -281,6 +320,9 @@ async def push_chunk(
         elif isinstance(chunk, str):
            # empty string is a valid chunk and should be processed normally
            pass
+        elif isinstance(chunk, dict):
+            # plain dict chunks are allowed (e.g., for OpenAI-compatible streaming)
+            pass
         else:
             raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}")

@@ -291,6 +333,11 @@ async def push_chunk(
         if self.include_generation_metadata and generation_info:
             self.current_generation_info = generation_info

+        # Dict chunks bypass prefix/suffix processing and go directly to _process
+        if isinstance(chunk, dict):
+            await self._process(chunk, generation_info)
+            return
+
         # Process prefix: accumulate until the expected prefix is received, then remove it.
         if self.prefix:
             if chunk is not None and chunk is not END_OF_STREAM:
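
Note: taken together, these changes make dict chunks first-class in the streaming handler. They skip prefix/suffix handling and completion tracking and are queued unchanged, which is what lets the server forward role/content deltas. A small usage sketch follows, assuming StreamingHandler can be constructed with its defaults; pushing None ends the stream, as described in the docstring above.

# Sketch only: exercises the dict-chunk path added in this commit.
# Assumes StreamingHandler's default constructor is sufficient.
import asyncio

from nemoguardrails.streaming import StreamingHandler


async def main():
    handler = StreamingHandler()

    # Dict chunks bypass buffering/prefix handling and are queued as-is.
    await handler.push_chunk({"role": "assistant"})
    await handler.push_chunk({"content": "Hello"})
    # None signals the end of the stream (converted to END_OF_STREAM).
    await handler.push_chunk(None)

    # Consume the stream the same way _format_streaming_response does.
    while True:
        try:
            chunk = await handler.__anext__()
        except StopAsyncIteration:
            break
        print(chunk)  # {'role': 'assistant'} then {'content': 'Hello'}


asyncio.run(main())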
