Skip to content

Commit 3e4b480

Browse files
vladnosivkaren-sy
and authored
fix: guided decoding params handling in vLLM (#4770)
Signed-off-by: Vladislav Nosivskoy <[email protected]> Co-authored-by: Karen Chung <[email protected]>
1 parent 197e022 commit 3e4b480

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

components/src/dynamo/vllm/handlers.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from vllm.inputs import TokensPrompt
1313
from vllm.lora.request import LoRARequest
1414
from vllm.outputs import RequestOutput
15-
from vllm.sampling_params import SamplingParams
15+
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
1616
from vllm.v1.engine.exceptions import EngineDeadError
1717

1818
from dynamo.llm import (
@@ -82,8 +82,22 @@ def build_sampling_params(
8282
sampling_params = SamplingParams(**default_sampling_params)
8383
sampling_params.detokenize = False
8484

85-
# Apply sampling_options
85+
# Handle guided_decoding - convert to StructuredOutputsParams
86+
guided_decoding = request["sampling_options"].get("guided_decoding")
87+
if guided_decoding is not None and isinstance(guided_decoding, dict):
88+
sampling_params.structured_outputs = StructuredOutputsParams(
89+
json=guided_decoding.get("json"),
90+
regex=guided_decoding.get("regex"),
91+
choice=guided_decoding.get("choice"),
92+
grammar=guided_decoding.get("grammar"),
93+
whitespace_pattern=guided_decoding.get("whitespace_pattern"),
94+
)
95+
96+
# Apply remaining sampling_options
8697
for key, value in request["sampling_options"].items():
98+
# Skip guided_decoding - already handled above
99+
if key == "guided_decoding":
100+
continue
87101
if value is not None and hasattr(sampling_params, key):
88102
setattr(sampling_params, key, value)
89103

tests/serve/test_vllm.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,66 @@ class VLLMConfig(EngineConfig):
454454
completion_payload_default(),
455455
],
456456
),
457+
"guided_decoding_json": VLLMConfig(
458+
name="guided_decoding_json",
459+
directory=vllm_dir,
460+
script_name="agg.sh",
461+
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
462+
model="Qwen/Qwen3-0.6B",
463+
request_payloads=[
464+
chat_payload(
465+
"Generate a person with name and age",
466+
repeat_count=1,
467+
expected_response=['"name"', '"age"'],
468+
temperature=0.0,
469+
max_tokens=100,
470+
extra_body={
471+
"guided_json": {
472+
"type": "object",
473+
"properties": {
474+
"name": {"type": "string"},
475+
"age": {"type": "integer"},
476+
},
477+
"required": ["name", "age"],
478+
}
479+
},
480+
)
481+
],
482+
),
483+
"guided_decoding_regex": VLLMConfig(
484+
name="guided_decoding_regex",
485+
directory=vllm_dir,
486+
script_name="agg.sh",
487+
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
488+
model="Qwen/Qwen3-0.6B",
489+
request_payloads=[
490+
chat_payload(
491+
"Generate a color name (red, blue, or green)",
492+
repeat_count=1,
493+
expected_response=["red", "blue", "green"],
494+
temperature=0.0,
495+
max_tokens=20,
496+
extra_body={"guided_regex": r"(red|blue|green)"},
497+
)
498+
],
499+
),
500+
"guided_decoding_choice": VLLMConfig(
501+
name="guided_decoding_choice",
502+
directory=vllm_dir,
503+
script_name="agg.sh",
504+
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
505+
model="Qwen/Qwen3-0.6B",
506+
request_payloads=[
507+
chat_payload(
508+
"Generate a color name (red, blue, or green)",
509+
repeat_count=1,
510+
expected_response=["red", "blue", "green"],
511+
temperature=0.0,
512+
max_tokens=20,
513+
extra_body={"guided_choice": ["red", "blue", "green"]},
514+
)
515+
],
516+
),
457517
}
458518

459519

tests/utils/payload_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def chat_payload(
134134
max_tokens: int = 300,
135135
temperature: Optional[float] = None,
136136
stream: bool = False,
137+
extra_body: Optional[Dict[str, Any]] = None,
137138
) -> ChatPayload:
138139
body: Dict[str, Any] = {
139140
"messages": [
@@ -148,6 +149,9 @@ def chat_payload(
148149
if temperature is not None:
149150
body["temperature"] = temperature
150151

152+
if extra_body:
153+
body.update(extra_body)
154+
151155
return ChatPayload(
152156
body=body,
153157
repeat_count=repeat_count,

0 commit comments

Comments (0)