108 changes: 107 additions & 1 deletion components/src/dynamo/vllm/handlers.py
@@ -73,7 +73,8 @@ def build_sampling_params(
Build SamplingParams from a PreprocessedRequest.

Args:
request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions',
and 'output_options'
default_sampling_params: Default sampling parameters to initialize with

Returns:
@@ -116,6 +117,41 @@ def build_sampling_params(
existing = sampling_params.stop_token_ids or []
sampling_params.stop_token_ids = list(set(existing).union(value))

# Apply output_options (logprobs, prompt_logprobs, etc.)
output_options = request.get("output_options", {})
if output_options:
# Handle logprobs - vLLM expects this as an integer or None
logprobs_value = output_options.get("logprobs")
if logprobs_value is not None and logprobs_value != "":
try:
parsed_logprobs = int(logprobs_value)
if parsed_logprobs < 0:
logger.warning(
f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring"
)
else:
sampling_params.logprobs = parsed_logprobs
except (ValueError, TypeError):
logger.warning(
f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring"
)

# Handle prompt_logprobs - vLLM expects this as an integer or None
prompt_logprobs_value = output_options.get("prompt_logprobs")
if prompt_logprobs_value is not None and prompt_logprobs_value != "":
try:
parsed_prompt_logprobs = int(prompt_logprobs_value)
if parsed_prompt_logprobs < 0:
logger.warning(
f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be non-negative), ignoring"
)
else:
sampling_params.prompt_logprobs = parsed_prompt_logprobs
except (ValueError, TypeError):
logger.warning(
f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be integer), ignoring"
)

# If max_tokens wasn't provided (None or missing), compute a dynamic default
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
token_ids = request.get("token_ids", [])
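
For illustration, a minimal sketch of how the new output_options fields flow through build_sampling_params; the request fragment is hypothetical, and passing an empty dict as default_sampling_params is an assumption about how the handler supplies engine defaults:

# Sketch: hypothetical request fragment whose field names follow the docstring above.
request = {
    "token_ids": [1, 2, 3],
    "sampling_options": {"temperature": 0.0},
    "stop_conditions": {"max_tokens": 16},
    "output_options": {"logprobs": 3, "prompt_logprobs": 1},
}

params = build_sampling_params(request, default_sampling_params={})

# Effect of the output_options block above:
#   params.logprobs == 3
#   params.prompt_logprobs == 1
# Invalid values (e.g. "abc" or -1) are logged with a warning and ignored.
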
@@ -577,6 +613,66 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
),
}

@staticmethod
def _extract_logprobs(
output, num_output_tokens_so_far: int
) -> tuple[list[float] | None, list[list[dict]] | None]:
"""
Extract logprobs from vLLM CompletionOutput for new tokens.

Args:
output: vLLM CompletionOutput object
num_output_tokens_so_far: Number of tokens already processed

Returns:
Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
- log_probs: List of log probabilities for each new token
- top_logprobs: List of top logprobs dicts for each new token
"""
if output.logprobs is None:
return None, None

# Get logprobs for new tokens only
new_logprobs = output.logprobs[num_output_tokens_so_far:]
if not new_logprobs:
return None, None

log_probs = []
top_logprobs = []

for token_idx, token_logprobs_dict in enumerate(new_logprobs):
if token_logprobs_dict is None:
continue

# Get the actual token_id that was generated at this position
actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]

# Extract log probability for the selected token
# vLLM guarantees the selected token is always in the logprobs dict
selected_logprob = token_logprobs_dict[actual_token_id]
log_probs.append(float(selected_logprob.logprob))

# Build top_logprobs list for this token position
token_top_logprobs = []
for tok_id, logprob_info in token_logprobs_dict.items():
token_top_logprobs.append(
{
"rank": (
logprob_info.rank if hasattr(logprob_info, "rank") else 0
),
"token_id": tok_id,
"token": (
logprob_info.decoded_token
if hasattr(logprob_info, "decoded_token")
else None
),
"logprob": float(logprob_info.logprob),
}
)
top_logprobs.append(token_top_logprobs)

return log_probs if log_probs else None, top_logprobs if top_logprobs else None
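
To illustrate the return shape, a small self-contained sketch; SimpleNamespace objects stand in for vLLM's CompletionOutput and Logprob, and Handler is a placeholder for the class that defines _extract_logprobs:

from types import SimpleNamespace

def lp(value, rank, tok):
    # Mimics vLLM's Logprob fields: .logprob, .rank, .decoded_token
    return SimpleNamespace(logprob=value, rank=rank, decoded_token=tok)

fake_output = SimpleNamespace(
    token_ids=[11, 42],
    logprobs=[
        {11: lp(-0.1, 1, "Hello"), 7: lp(-2.3, 2, "Hi")},
        {42: lp(-0.5, 1, " world"), 9: lp(-1.9, 2, " there")},
    ],
)

log_probs, top_logprobs = Handler._extract_logprobs(fake_output, num_output_tokens_so_far=0)
# log_probs    -> [-0.1, -0.5]
# top_logprobs -> one list per new token, each entry a dict with
#                 "rank", "token_id", "token", and "logprob" keys
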

async def generate_tokens(
self,
prompt,
@@ -622,6 +718,16 @@ async def generate_tokens(
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}

# Extract logprobs for new tokens if available
log_probs, top_logprobs = self._extract_logprobs(
output, num_output_tokens_so_far
)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs

if output.finish_reason:
out["finish_reason"] = output.finish_reason
out[
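
Before moving on to the tests: with logprobs requested, the streamed delta built in generate_tokens ends up looking roughly like this (illustrative values; finish/stop fields are added on the final iteration as in the existing code):

delta = {
    "token_ids": [42],
    "log_probs": [-0.5],
    "top_logprobs": [
        [
            {"rank": 1, "token_id": 42, "token": " world", "logprob": -0.5},
            {"rank": 2, "token_id": 9, "token": " there", "logprob": -1.9},
        ]
    ],
}
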
25 changes: 25 additions & 0 deletions tests/serve/test_vllm.py
@@ -19,7 +19,9 @@
from tests.utils.payload_builder import (
chat_payload,
chat_payload_default,
chat_payload_with_logprobs,
completion_payload_default,
completion_payload_with_logprobs,
metric_payload_default,
)
from tests.utils.payloads import ToolCallingChatPayload
@@ -59,6 +61,29 @@ class VLLMConfig(EngineConfig):
metric_payload_default(min_num_requests=6, backend="vllm"),
],
),
"aggregated_logprobs": VLLMConfig(
name="aggregated_logprobs",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
top_logprobs=3,
),
completion_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
logprobs=5,
),
],
),
"aggregated_lmcache": VLLMConfig(
name="aggregated_lmcache",
directory=vllm_dir,
84 changes: 84 additions & 0 deletions tests/utils/payload_builder.py
@@ -153,6 +153,10 @@ def chat_payload(
}
if temperature is not None:
body["temperature"] = temperature
if logprobs is not None:
body["logprobs"] = logprobs

if top_logprobs is not None:
body["top_logprobs"] = top_logprobs
@@ -307,3 +311,83 @@ def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool:
return False

return _check_completions_endpoint


def chat_payload_with_logprobs(
content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT,
repeat_count: int = 1,
expected_response: Optional[List[str]] = None,
max_tokens: int = 50,
temperature: float = 0.0,
top_logprobs: int = 3,
) -> ChatPayloadWithLogprobs:
"""
Create a chat payload that requests and validates logprobs in the response.

Args:
content: Message content (text or structured content list)
repeat_count: Number of times to repeat the request
expected_response: List of strings expected in the response text
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_logprobs: Number of top logprobs to return per token

Returns:
ChatPayloadWithLogprobs that validates logprobs in response
"""
body: Dict[str, Any] = {
"messages": [
{
"role": "user",
"content": content,
}
],
"max_tokens": max_tokens,
"temperature": temperature,
"logprobs": True,
"top_logprobs": top_logprobs,
}

return ChatPayloadWithLogprobs(
body=body,
repeat_count=repeat_count,
expected_log=[],
expected_response=expected_response or ["AI", "knock", "joke"],
)
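
For reference, a sketch of how the helper is used and the request body it builds (values shown match the defaults above; TEXT_PROMPT comes from this module):

payload = chat_payload_with_logprobs(max_tokens=30, temperature=0.0, top_logprobs=3)
# payload.body ->
# {
#     "messages": [{"role": "user", "content": TEXT_PROMPT}],
#     "max_tokens": 30,
#     "temperature": 0.0,
#     "logprobs": True,        # chat API: boolean flag
#     "top_logprobs": 3,       # chat API: number of alternatives per token
# }
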


def completion_payload_with_logprobs(
prompt: str = TEXT_PROMPT,
repeat_count: int = 1,
expected_response: Optional[List[str]] = None,
max_tokens: int = 50,
temperature: float = 0.0,
logprobs: int = 5,
) -> CompletionPayloadWithLogprobs:
"""
Create a completion payload that requests and validates logprobs in the response.

Args:
prompt: Text prompt
repeat_count: Number of times to repeat the request
expected_response: List of strings expected in the response text
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
logprobs: Number of logprobs to return per token

Returns:
CompletionPayloadWithLogprobs that validates logprobs in response
"""
body: Dict[str, Any] = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"logprobs": logprobs,
}

return CompletionPayloadWithLogprobs(
body=body,
repeat_count=repeat_count,
expected_log=[],
expected_response=expected_response or ["AI", "knock", "joke"],
)
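
And the completions counterpart; note that the completions API takes logprobs as an integer rather than a boolean flag:

payload = completion_payload_with_logprobs(max_tokens=30, temperature=0.0, logprobs=5)
# payload.body ->
# {
#     "prompt": TEXT_PROMPT,
#     "max_tokens": 30,
#     "temperature": 0.0,
#     "logprobs": 5,           # completions API: integer top-k, not a boolean
# }
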
23 changes: 23 additions & 0 deletions tests/utils/payloads.py
@@ -14,6 +14,7 @@
# limitations under the License.

import logging
import math
import re
import time
from copy import deepcopy
@@ -183,6 +184,14 @@ def validate(self, response: Any, content: str) -> None:
"top_logprobs" in item
), "Missing 'top_logprobs' in logprobs content"

# Sanity check: logprob should be valid (not nan/inf/positive)
logprob_val = item["logprob"]
assert not math.isnan(logprob_val), "logprob is NaN"
assert not math.isinf(logprob_val), "logprob is infinite"
assert (
logprob_val <= 0
), f"logprob should be <= 0, got {logprob_val}"

logger.info(
f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
)
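
For reference, a chat response fragment that satisfies the checks above would look roughly like this (illustrative values, OpenAI-style chat logprobs shape):

# choices[0].logprobs.content — one entry per generated token:
content_logprobs = [
    {
        "token": "Sure",
        "logprob": -0.12,   # finite and <= 0
        "top_logprobs": [
            {"token": "Sure", "logprob": -0.12},
            {"token": "Okay", "logprob": -2.31},
        ],
    },
]
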
@@ -281,6 +290,20 @@ def validate(self, response: Any, content: str) -> None:
assert len(token_logprobs) == len(
tokens
), "Mismatch between token_logprobs and tokens length"

# Sanity check: each logprob should be valid (not nan/inf/positive)
for i, logprob_val in enumerate(token_logprobs):
if logprob_val is not None: # First token can be None
assert not math.isnan(
logprob_val
), f"logprob at index {i} is NaN"
assert not math.isinf(
logprob_val
), f"logprob at index {i} is infinite"
assert (
logprob_val <= 0
), f"logprob at index {i} should be <= 0, got {logprob_val}"

logger.info(
f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
)
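
And the completions-style shape these checks expect (illustrative values; the first token_logprobs entry may be None, e.g. when the prompt is echoed):

# choices[0].logprobs for the completions API:
completion_logprobs = {
    "tokens": ["Sure", ","],
    "token_logprobs": [-0.12, -0.45],   # same length as tokens; each finite and <= 0
    "top_logprobs": [
        {"Sure": -0.12, "Okay": -2.31},
        {",": -0.45, ".": -1.7},
    ],
}
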