Commit 3566e82

feat: Add logprobs support to vLLM backend (#4683)
Implement logprobs functionality for the vLLM backend, passing log probability information from vLLM through to the frontend.

Signed-off-by: Aryan Bagade <[email protected]>
1 parent 71f94ed commit 3566e82

File tree: 4 files changed (+280, -7 lines)

components/src/dynamo/vllm/handlers.py

Lines changed: 90 additions & 1 deletion
@@ -73,7 +73,8 @@ def build_sampling_params(
     Build SamplingParams from a PreprocessedRequest.

     Args:
-        request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions'
+        request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions',
+            and 'output_options'
        default_sampling_params: Default sampling parameters to initialize with

     Returns:

@@ -95,6 +96,19 @@ def build_sampling_params(
             continue
         setattr(sampling_params, key, value)

+    # Apply output_options (logprobs, prompt_logprobs, etc.)
+    output_options = request.get("output_options", {})
+    if output_options:
+        # Handle logprobs - vLLM expects this as an integer or None
+        logprobs_value = output_options.get("logprobs")
+        if logprobs_value is not None:
+            sampling_params.logprobs = int(logprobs_value)
+
+        # Handle prompt_logprobs - vLLM expects this as an integer or None
+        prompt_logprobs_value = output_options.get("prompt_logprobs")
+        if prompt_logprobs_value is not None:
+            sampling_params.prompt_logprobs = int(prompt_logprobs_value)
+
     # If max_tokens wasn't provided (None or missing), compute a dynamic default
     provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
     token_ids = request.get("token_ids", [])

@@ -556,6 +570,71 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
             ),
         }

+    @staticmethod
+    def _extract_logprobs(
+        output, num_output_tokens_so_far: int
+    ) -> tuple[list[float] | None, list[list[dict]] | None]:
+        """
+        Extract logprobs from vLLM CompletionOutput for new tokens.
+
+        Args:
+            output: vLLM CompletionOutput object
+            num_output_tokens_so_far: Number of tokens already processed
+
+        Returns:
+            Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
+            - log_probs: List of log probabilities for each new token
+            - top_logprobs: List of top logprobs dicts for each new token
+        """
+        if output.logprobs is None:
+            return None, None
+
+        # Get logprobs for new tokens only
+        new_logprobs = output.logprobs[num_output_tokens_so_far:]
+        if not new_logprobs:
+            return None, None
+
+        log_probs = []
+        top_logprobs = []
+
+        for token_idx, token_logprobs_dict in enumerate(new_logprobs):
+            if token_logprobs_dict is None:
+                continue
+
+            # Get the actual token_id that was generated at this position
+            actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]
+
+            # Extract log probability for the selected token
+            if actual_token_id in token_logprobs_dict:
+                selected_logprob = token_logprobs_dict[actual_token_id]
+                log_probs.append(float(selected_logprob.logprob))
+            else:
+                # Fallback: use the first logprob if selected token not found
+                first_logprob = next(iter(token_logprobs_dict.values()), None)
+                if first_logprob:
+                    log_probs.append(float(first_logprob.logprob))
+
+            # Build top_logprobs list for this token position
+            token_top_logprobs = []
+            for tok_id, logprob_info in token_logprobs_dict.items():
+                token_top_logprobs.append(
+                    {
+                        "rank": logprob_info.rank
+                        if hasattr(logprob_info, "rank")
+                        else 0,
+                        "token_id": tok_id,
+                        "token": (
+                            logprob_info.decoded_token
+                            if hasattr(logprob_info, "decoded_token")
+                            else None
+                        ),
+                        "logprob": float(logprob_info.logprob),
+                    }
+                )
+            top_logprobs.append(token_top_logprobs)
+
+        return log_probs if log_probs else None, top_logprobs if top_logprobs else None
+
     async def generate_tokens(
         self,
         prompt,

@@ -601,6 +680,16 @@ async def generate_tokens(
             output = res.outputs[0]
             next_total_toks = len(output.token_ids)
             out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
+
+            # Extract logprobs for new tokens if available
+            log_probs, top_logprobs = self._extract_logprobs(
+                output, num_output_tokens_so_far
+            )
+            if log_probs is not None:
+                out["log_probs"] = log_probs
+            if top_logprobs is not None:
+                out["top_logprobs"] = top_logprobs
+
            if output.finish_reason:
                out["finish_reason"] = output.finish_reason
                out[
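For context on how the new output_options block is consumed, here is a minimal sketch (not part of the commit; assumes vLLM is installed, and the request values are purely illustrative) that mirrors the mapping from a Dynamo PreprocessedRequest onto vLLM SamplingParams:

    # Sketch only: mirrors the added block in build_sampling_params above.
    from vllm import SamplingParams  # assumes vLLM is installed locally

    request = {
        "token_ids": [151644, 872, 198],           # preprocessed prompt tokens (example values)
        "sampling_options": {"temperature": 0.0},
        "stop_conditions": {"max_tokens": 30},
        "output_options": {"logprobs": 3, "prompt_logprobs": None},
    }

    sampling_params = SamplingParams(temperature=0.0, max_tokens=30)

    # logprobs / prompt_logprobs are copied over only when present,
    # and coerced to int because vLLM expects an integer or None.
    output_options = request.get("output_options", {})
    if output_options.get("logprobs") is not None:
        sampling_params.logprobs = int(output_options["logprobs"])
    if output_options.get("prompt_logprobs") is not None:
        sampling_params.prompt_logprobs = int(output_options["prompt_logprobs"])

    print(sampling_params.logprobs)  # -> 3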

tests/serve/test_vllm.py

Lines changed: 25 additions & 0 deletions
@@ -18,7 +18,9 @@
 from tests.utils.payload_builder import (
     chat_payload,
     chat_payload_default,
+    chat_payload_with_logprobs,
     completion_payload_default,
+    completion_payload_with_logprobs,
     metric_payload_default,
 )

@@ -51,6 +53,29 @@ class VLLMConfig(EngineConfig):
             metric_payload_default(min_num_requests=6, backend="vllm"),
         ],
     ),
+    "aggregated_logprobs": VLLMConfig(
+        name="aggregated_logprobs",
+        directory=vllm_dir,
+        script_name="agg.sh",
+        marks=[pytest.mark.gpu_1],
+        model="Qwen/Qwen3-0.6B",
+        request_payloads=[
+            chat_payload_with_logprobs(
+                repeat_count=2,
+                expected_response=["AI", "knock", "joke"],
+                max_tokens=30,
+                temperature=0.0,
+                top_logprobs=3,
+            ),
+            completion_payload_with_logprobs(
+                repeat_count=2,
+                expected_response=["AI", "knock", "joke"],
+                max_tokens=30,
+                temperature=0.0,
+                logprobs=5,
+            ),
+        ],
+    ),
    "aggregated_lmcache": VLLMConfig(
        name="aggregated_lmcache",
        directory=vllm_dir,
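The new aggregated_logprobs configuration exercises the OpenAI-compatible frontend end to end. The sketch below (not part of the commit) shows the kind of chat request it issues; the host, port, endpoint path, and prompt text are assumptions, while the model and logprobs parameters match the test config above:

    import requests

    # Assumed local deployment started via agg.sh; adjust URL to your setup.
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "Qwen/Qwen3-0.6B",
            "messages": [{"role": "user", "content": "Tell me a short joke about AI"}],
            "max_tokens": 30,
            "temperature": 0.0,
            "logprobs": True,
            "top_logprobs": 3,
        },
        timeout=60,
    )
    choice = resp.json()["choices"][0]
    # With this change, each generated token should carry a logprob and its top-3 alternatives.
    for item in choice["logprobs"]["content"]:
        print(item["token"], item["logprob"], len(item["top_logprobs"]))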

tests/utils/payload_builder.py

Lines changed: 99 additions & 6 deletions
@@ -6,7 +6,9 @@
 from tests.utils.client import send_request
 from tests.utils.payloads import (
     ChatPayload,
+    ChatPayloadWithLogprobs,
     CompletionPayload,
+    CompletionPayloadWithLogprobs,
     EmbeddingPayload,
     MetricsPayload,
 )

@@ -134,6 +136,8 @@ def chat_payload(
     max_tokens: int = 300,
     temperature: Optional[float] = None,
     stream: bool = False,
+    logprobs: Optional[int] = None,
+    top_logprobs: Optional[int] = None,
 ) -> ChatPayload:
     body: Dict[str, Any] = {
         "messages": [

@@ -147,6 +151,10 @@ def chat_payload(
     }
     if temperature is not None:
         body["temperature"] = temperature
+    if logprobs is not None:
+        body["logprobs"] = logprobs
+    if top_logprobs is not None:
+        body["top_logprobs"] = top_logprobs

     return ChatPayload(
         body=body,

@@ -164,14 +172,19 @@ def completion_payload(
     max_tokens: int = 150,
     temperature: float = 0.1,
     stream: bool = False,
+    logprobs: Optional[int] = None,
 ) -> CompletionPayload:
+    body: Dict[str, Any] = {
+        "prompt": prompt,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "stream": stream,
+    }
+    if logprobs is not None:
+        body["logprobs"] = logprobs
+
     return CompletionPayload(
-        body={
-            "prompt": prompt,
-            "max_tokens": max_tokens,
-            "temperature": temperature,
-            "stream": stream,
-        },
+        body=body,
         repeat_count=repeat_count,
         expected_log=expected_log or [],
         expected_response=expected_response or [],

@@ -276,3 +289,83 @@ def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool:
             return False

     return _check_completions_endpoint
+
+
+def chat_payload_with_logprobs(
+    content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT,
+    repeat_count: int = 1,
+    expected_response: Optional[List[str]] = None,
+    max_tokens: int = 50,
+    temperature: float = 0.0,
+    top_logprobs: int = 3,
+) -> ChatPayloadWithLogprobs:
+    """
+    Create a chat payload that requests and validates logprobs in the response.
+
+    Args:
+        content: Message content (text or structured content list)
+        repeat_count: Number of times to repeat the request
+        expected_response: List of strings expected in the response text
+        max_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        top_logprobs: Number of top logprobs to return per token
+
+    Returns:
+        ChatPayloadWithLogprobs that validates logprobs in response
+    """
+    body: Dict[str, Any] = {
+        "messages": [
+            {
+                "role": "user",
+                "content": content,
+            }
+        ],
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "logprobs": True,
+        "top_logprobs": top_logprobs,
+    }
+
+    return ChatPayloadWithLogprobs(
+        body=body,
+        repeat_count=repeat_count,
+        expected_log=[],
+        expected_response=expected_response or ["AI", "knock", "joke"],
+    )
+
+
+def completion_payload_with_logprobs(
+    prompt: str = TEXT_PROMPT,
+    repeat_count: int = 1,
+    expected_response: Optional[List[str]] = None,
+    max_tokens: int = 50,
+    temperature: float = 0.0,
+    logprobs: int = 5,
+) -> CompletionPayloadWithLogprobs:
+    """
+    Create a completion payload that requests and validates logprobs in the response.
+
+    Args:
+        prompt: Text prompt
+        repeat_count: Number of times to repeat the request
+        expected_response: List of strings expected in the response text
+        max_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        logprobs: Number of logprobs to return per token
+
+    Returns:
+        CompletionPayloadWithLogprobs that validates logprobs in response
+    """
+    body: Dict[str, Any] = {
+        "prompt": prompt,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "logprobs": logprobs,
+    }
+
+    return CompletionPayloadWithLogprobs(
+        body=body,
+        repeat_count=repeat_count,
+        expected_log=[],
+        expected_response=expected_response or ["AI", "knock", "joke"],
+    )
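A design note: the two builders follow the differing OpenAI conventions for the two endpoints, so the chat body carries a boolean logprobs plus an integer top_logprobs, while the completions body carries a single integer logprobs. A quick illustrative check (not part of the commit; run from the repo root so tests.utils is importable):

    from tests.utils.payload_builder import (
        chat_payload_with_logprobs,
        completion_payload_with_logprobs,
    )

    chat = chat_payload_with_logprobs(top_logprobs=3, max_tokens=30, temperature=0.0)
    comp = completion_payload_with_logprobs(logprobs=5, max_tokens=30, temperature=0.0)

    # Chat endpoint: boolean flag plus top_logprobs count.
    assert chat.body["logprobs"] is True and chat.body["top_logprobs"] == 3
    # Completions endpoint: a single integer logprobs count.
    assert comp.body["logprobs"] == 5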

tests/utils/payloads.py

Lines changed: 66 additions & 0 deletions
@@ -155,6 +155,39 @@ def response_handler(self, response: Any) -> str:
         return ChatPayload.extract_content(response)


+@dataclass
+class ChatPayloadWithLogprobs(ChatPayload):
+    """Chat payload that validates logprobs in response."""
+
+    def validate(self, response: Any, content: str) -> None:
+        """Validate response contains logprobs fields."""
+        super().validate(response, content)
+
+        result = response.json()
+        choice = result["choices"][0]
+
+        # Validate logprobs field exists
+        assert "logprobs" in choice, "Missing 'logprobs' in choice"
+
+        logprobs_data = choice["logprobs"]
+        if logprobs_data is not None:
+            assert "content" in logprobs_data, "Missing 'content' in logprobs"
+            content_logprobs = logprobs_data["content"]
+
+            if content_logprobs:
+                # Validate structure of logprobs
+                for item in content_logprobs:
+                    assert "token" in item, "Missing 'token' in logprobs content"
+                    assert "logprob" in item, "Missing 'logprob' in logprobs content"
+                    assert (
+                        "top_logprobs" in item
+                    ), "Missing 'top_logprobs' in logprobs content"
+
+                logger.info(
+                    f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
+                )
+
+
 @dataclass
 class CompletionPayload(BasePayload):
     """Payload for completions endpoint."""

@@ -177,6 +210,39 @@ def response_handler(self, response: Any) -> str:
         return CompletionPayload.extract_text(response)


+@dataclass
+class CompletionPayloadWithLogprobs(CompletionPayload):
+    """Completion payload that validates logprobs in response."""
+
+    def validate(self, response: Any, content: str) -> None:
+        """Validate response contains logprobs fields."""
+        super().validate(response, content)
+
+        result = response.json()
+        choice = result["choices"][0]
+
+        # Validate logprobs field exists
+        assert "logprobs" in choice, "Missing 'logprobs' in choice"
+
+        logprobs_data = choice["logprobs"]
+        if logprobs_data is not None:
+            assert (
+                "token_logprobs" in logprobs_data
+            ), "Missing 'token_logprobs' in logprobs"
+            assert "tokens" in logprobs_data, "Missing 'tokens' in logprobs"
+
+            token_logprobs = logprobs_data["token_logprobs"]
+            tokens = logprobs_data["tokens"]
+
+            if token_logprobs:
+                assert len(token_logprobs) == len(
+                    tokens
+                ), "Mismatch between token_logprobs and tokens length"
+                logger.info(
+                    f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
+                )
+
+
 @dataclass
 class EmbeddingPayload(BasePayload):
     """Payload for embeddings endpoint."""
