
Commit 7071697

AryanBagade and rmccorm4 authored and committed
feat: Add logprobs support to vLLM backend (ai-dynamo#4683) (ai-dynamo#4697)
Signed-off-by: Aryan Bagade <[email protected]> Signed-off-by: Aryan Bagade <[email protected]> Co-authored-by: Ryan McCormick <[email protected]>
1 parent 951989d commit 7071697

4 files changed: +239 −1 lines changed

components/src/dynamo/vllm/handlers.py

Lines changed: 107 additions & 1 deletion
@@ -73,7 +73,8 @@ def build_sampling_params(
     Build SamplingParams from a PreprocessedRequest.
 
     Args:
-        request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions'
+        request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions',
+            and 'output_options'
         default_sampling_params: Default sampling parameters to initialize with
 
     Returns:
@@ -116,6 +117,41 @@ def build_sampling_params(
             existing = sampling_params.stop_token_ids or []
             sampling_params.stop_token_ids = list(set(existing).union(value))
 
+    # Apply output_options (logprobs, prompt_logprobs, etc.)
+    output_options = request.get("output_options", {})
+    if output_options:
+        # Handle logprobs - vLLM expects this as an integer or None
+        logprobs_value = output_options.get("logprobs")
+        if logprobs_value is not None and logprobs_value != "":
+            try:
+                parsed_logprobs = int(logprobs_value)
+                if parsed_logprobs < 0:
+                    logger.warning(
+                        f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring"
+                    )
+                else:
+                    sampling_params.logprobs = parsed_logprobs
+            except (ValueError, TypeError):
+                logger.warning(
+                    f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring"
+                )
+
+        # Handle prompt_logprobs - vLLM expects this as an integer or None
+        prompt_logprobs_value = output_options.get("prompt_logprobs")
+        if prompt_logprobs_value is not None and prompt_logprobs_value != "":
+            try:
+                parsed_prompt_logprobs = int(prompt_logprobs_value)
+                if parsed_prompt_logprobs < 0:
+                    logger.warning(
+                        f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be non-negative), ignoring"
+                    )
+                else:
+                    sampling_params.prompt_logprobs = parsed_prompt_logprobs
+            except (ValueError, TypeError):
+                logger.warning(
+                    f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be integer), ignoring"
+                )
+
     # If max_tokens wasn't provided (None or missing), compute a dynamic default
     provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
     token_ids = request.get("token_ids", [])
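For context, here is a minimal sketch of how the new output_options plumbing is expected to behave. The request dict below is made up for illustration, and a fresh SamplingParams() stands in for the handler's default_sampling_params; vllm.SamplingParams and its logprobs/prompt_logprobs fields are the real vLLM API the diff above assigns to.

    from vllm import SamplingParams

    # Hypothetical preprocessed request, shaped like the dict the handler receives.
    request = {
        "token_ids": [1, 2, 3],
        "sampling_options": {"temperature": 0.0},
        "stop_conditions": {"max_tokens": 30},
        "output_options": {"logprobs": 3, "prompt_logprobs": 0},
    }

    params = SamplingParams()
    opts = request.get("output_options", {})
    if opts.get("logprobs") not in (None, ""):
        params.logprobs = int(opts["logprobs"])  # number of top candidates vLLM reports per generated token
    if opts.get("prompt_logprobs") not in (None, ""):
        params.prompt_logprobs = int(opts["prompt_logprobs"])  # forwarded straight to vLLM's prompt_logprobs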
@@ -577,6 +613,66 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
             ),
         }
 
+    @staticmethod
+    def _extract_logprobs(
+        output, num_output_tokens_so_far: int
+    ) -> tuple[list[float] | None, list[list[dict]] | None]:
+        """
+        Extract logprobs from vLLM CompletionOutput for new tokens.
+
+        Args:
+            output: vLLM CompletionOutput object
+            num_output_tokens_so_far: Number of tokens already processed
+
+        Returns:
+            Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
+            - log_probs: List of log probabilities for each new token
+            - top_logprobs: List of top logprobs dicts for each new token
+        """
+        if output.logprobs is None:
+            return None, None
+
+        # Get logprobs for new tokens only
+        new_logprobs = output.logprobs[num_output_tokens_so_far:]
+        if not new_logprobs:
+            return None, None
+
+        log_probs = []
+        top_logprobs = []
+
+        for token_idx, token_logprobs_dict in enumerate(new_logprobs):
+            if token_logprobs_dict is None:
+                continue
+
+            # Get the actual token_id that was generated at this position
+            actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]
+
+            # Extract log probability for the selected token
+            # vLLM guarantees the selected token is always in the logprobs dict
+            selected_logprob = token_logprobs_dict[actual_token_id]
+            log_probs.append(float(selected_logprob.logprob))
+
+            # Build top_logprobs list for this token position
+            token_top_logprobs = []
+            for tok_id, logprob_info in token_logprobs_dict.items():
+                token_top_logprobs.append(
+                    {
+                        "rank": (
+                            logprob_info.rank if hasattr(logprob_info, "rank") else 0
+                        ),
+                        "token_id": tok_id,
+                        "token": (
+                            logprob_info.decoded_token
+                            if hasattr(logprob_info, "decoded_token")
+                            else None
+                        ),
+                        "logprob": float(logprob_info.logprob),
+                    }
+                )
+            top_logprobs.append(token_top_logprobs)
+
+        return log_probs if log_probs else None, top_logprobs if top_logprobs else None
+
     async def generate_tokens(
         self,
         prompt,
@@ -622,6 +718,16 @@ async def generate_tokens(
             output = res.outputs[0]
             next_total_toks = len(output.token_ids)
             out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
+
+            # Extract logprobs for new tokens if available
+            log_probs, top_logprobs = self._extract_logprobs(
+                output, num_output_tokens_so_far
+            )
+            if log_probs is not None:
+                out["log_probs"] = log_probs
+            if top_logprobs is not None:
+                out["top_logprobs"] = top_logprobs
+
             if output.finish_reason:
                 out["finish_reason"] = output.finish_reason
                 out[
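Putting the two hunks together, each streamed delta gains optional log_probs and top_logprobs fields in the format _extract_logprobs documents. A sketch of one such delta (token IDs, strings, and probabilities are made up):

    out = {
        "token_ids": [9906],
        "log_probs": [-0.0231],            # logprob of each newly sampled token
        "top_logprobs": [                  # one list of candidate dicts per new token
            [
                {"rank": 1, "token_id": 9906, "token": "Hello", "logprob": -0.0231},
                {"rank": 2, "token_id": 9907, "token": "Hi", "logprob": -3.871},
            ]
        ],
    }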

tests/serve/test_vllm.py

Lines changed: 25 additions & 0 deletions
@@ -19,7 +19,9 @@
 from tests.utils.payload_builder import (
     chat_payload,
     chat_payload_default,
+    chat_payload_with_logprobs,
     completion_payload_default,
+    completion_payload_with_logprobs,
     metric_payload_default,
 )
 from tests.utils.payloads import ToolCallingChatPayload
@@ -59,6 +61,29 @@ class VLLMConfig(EngineConfig):
             metric_payload_default(min_num_requests=6, backend="vllm"),
         ],
     ),
+    "aggregated_logprobs": VLLMConfig(
+        name="aggregated_logprobs",
+        directory=vllm_dir,
+        script_name="agg.sh",
+        marks=[pytest.mark.gpu_1],
+        model="Qwen/Qwen3-0.6B",
+        request_payloads=[
+            chat_payload_with_logprobs(
+                repeat_count=2,
+                expected_response=["AI", "knock", "joke"],
+                max_tokens=30,
+                temperature=0.0,
+                top_logprobs=3,
+            ),
+            completion_payload_with_logprobs(
+                repeat_count=2,
+                expected_response=["AI", "knock", "joke"],
+                max_tokens=30,
+                temperature=0.0,
+                logprobs=5,
+            ),
+        ],
+    ),
     "aggregated_lmcache": VLLMConfig(
         name="aggregated_lmcache",
         directory=vllm_dir,
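For a rough sense of what this config exercises end to end: a request like the one below, sent to a locally running OpenAI-compatible frontend, should come back with a logprobs block on each choice. The URL, port, and prompt text are placeholders, not values from the commit.

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # placeholder endpoint for a local deployment
        json={
            "model": "Qwen/Qwen3-0.6B",
            "messages": [{"role": "user", "content": "Tell me a short joke about AI"}],
            "max_tokens": 30,
            "temperature": 0.0,
            "logprobs": True,
            "top_logprobs": 3,
        },
        timeout=60,
    )
    for item in resp.json()["choices"][0]["logprobs"]["content"]:
        print(item["token"], item["logprob"], len(item["top_logprobs"]))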

tests/utils/payload_builder.py

Lines changed: 84 additions & 0 deletions
@@ -153,6 +153,10 @@ def chat_payload(
     }
     if temperature is not None:
         body["temperature"] = temperature
+    if logprobs is not None:
+        body["logprobs"] = logprobs
+    if top_logprobs is not None:
+        body["top_logprobs"] = top_logprobs
 
     if top_logprobs is not None:
         body["top_logprobs"] = top_logprobs
@@ -307,3 +311,83 @@ def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool:
             return False
 
     return _check_completions_endpoint
+
+
+def chat_payload_with_logprobs(
+    content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT,
+    repeat_count: int = 1,
+    expected_response: Optional[List[str]] = None,
+    max_tokens: int = 50,
+    temperature: float = 0.0,
+    top_logprobs: int = 3,
+) -> ChatPayloadWithLogprobs:
+    """
+    Create a chat payload that requests and validates logprobs in the response.
+
+    Args:
+        content: Message content (text or structured content list)
+        repeat_count: Number of times to repeat the request
+        expected_response: List of strings expected in the response text
+        max_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        top_logprobs: Number of top logprobs to return per token
+
+    Returns:
+        ChatPayloadWithLogprobs that validates logprobs in response
+    """
+    body: Dict[str, Any] = {
+        "messages": [
+            {
+                "role": "user",
+                "content": content,
+            }
+        ],
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "logprobs": True,
+        "top_logprobs": top_logprobs,
+    }
+
+    return ChatPayloadWithLogprobs(
+        body=body,
+        repeat_count=repeat_count,
+        expected_log=[],
+        expected_response=expected_response or ["AI", "knock", "joke"],
+    )
+
+
+def completion_payload_with_logprobs(
+    prompt: str = TEXT_PROMPT,
+    repeat_count: int = 1,
+    expected_response: Optional[List[str]] = None,
+    max_tokens: int = 50,
+    temperature: float = 0.0,
+    logprobs: int = 5,
+) -> CompletionPayloadWithLogprobs:
+    """
+    Create a completion payload that requests and validates logprobs in the response.
+
+    Args:
+        prompt: Text prompt
+        repeat_count: Number of times to repeat the request
+        expected_response: List of strings expected in the response text
+        max_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        logprobs: Number of logprobs to return per token
+
+    Returns:
+        CompletionPayloadWithLogprobs that validates logprobs in response
+    """
+    body: Dict[str, Any] = {
+        "prompt": prompt,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "logprobs": logprobs,
+    }
+
+    return CompletionPayloadWithLogprobs(
+        body=body,
+        repeat_count=repeat_count,
+        expected_log=[],
+        expected_response=expected_response or ["AI", "knock", "joke"],
+    )
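Side by side, the two helpers encode the OpenAI-style difference between the endpoints: the chat body uses a boolean "logprobs" plus a "top_logprobs" count, while the completions body takes a single integer "logprobs". Roughly the bodies they build look like this (the prompt text is a placeholder; the other values are the helper defaults from the diff above):

    chat_body = {
        "messages": [{"role": "user", "content": "Tell me a short joke about AI"}],
        "max_tokens": 50,
        "temperature": 0.0,
        "logprobs": True,   # chat API: boolean switch
        "top_logprobs": 3,  # number of candidates per token
    }

    completion_body = {
        "prompt": "Tell me a short joke about AI",
        "max_tokens": 50,
        "temperature": 0.0,
        "logprobs": 5,      # completions API: integer count directly
    }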

tests/utils/payloads.py

Lines changed: 23 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+import math
 import re
 import time
 from copy import deepcopy
@@ -183,6 +184,14 @@ def validate(self, response: Any, content: str) -> None:
                 "top_logprobs" in item
             ), "Missing 'top_logprobs' in logprobs content"
 
+            # Sanity check: logprob should be valid (not nan/inf/positive)
+            logprob_val = item["logprob"]
+            assert not math.isnan(logprob_val), "logprob is NaN"
+            assert not math.isinf(logprob_val), "logprob is infinite"
+            assert (
+                logprob_val <= 0
+            ), f"logprob should be <= 0, got {logprob_val}"
+
         logger.info(
             f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
         )
@@ -281,6 +290,20 @@ def validate(self, response: Any, content: str) -> None:
         assert len(token_logprobs) == len(
             tokens
         ), "Mismatch between token_logprobs and tokens length"
+
+        # Sanity check: each logprob should be valid (not nan/inf/positive)
+        for i, logprob_val in enumerate(token_logprobs):
+            if logprob_val is not None:  # First token can be None
+                assert not math.isnan(
+                    logprob_val
+                ), f"logprob at index {i} is NaN"
+                assert not math.isinf(
+                    logprob_val
+                ), f"logprob at index {i} is infinite"
+                assert (
+                    logprob_val <= 0
+                ), f"logprob at index {i} should be <= 0, got {logprob_val}"
+
         logger.info(
             f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
         )
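The bounds these assertions enforce follow directly from the definition: a log probability is log(p) with 0 < p <= 1, so a valid value is always finite and non-positive. A tiny standalone version of the same check, for illustration only (not part of the commit):

    import math

    def assert_valid_logprob(value: float) -> None:
        # log(p) for p in (0, 1] is always finite and <= 0
        assert not math.isnan(value), "logprob is NaN"
        assert not math.isinf(value), "logprob is infinite"
        assert value <= 0, f"logprob should be <= 0, got {value}"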
