91 changes: 90 additions & 1 deletion components/src/dynamo/vllm/handlers.py
@@ -73,7 +73,8 @@ def build_sampling_params(
Build SamplingParams from a PreprocessedRequest.

Args:
-        request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions'
+        request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions',
+            and 'output_options'
default_sampling_params: Default sampling parameters to initialize with

Returns:
@@ -102,6 +103,19 @@ def build_sampling_params(
existing = sampling_params.stop_token_ids or []
sampling_params.stop_token_ids = list(set(existing).union(value))

# Apply output_options (logprobs, prompt_logprobs, etc.)
output_options = request.get("output_options", {})
if output_options:
# Handle logprobs - vLLM expects this as an integer or None
logprobs_value = output_options.get("logprobs")
if logprobs_value is not None:
sampling_params.logprobs = int(logprobs_value)

# Handle prompt_logprobs - vLLM expects this as an integer or None
prompt_logprobs_value = output_options.get("prompt_logprobs")
if prompt_logprobs_value is not None:
sampling_params.prompt_logprobs = int(prompt_logprobs_value)

# If max_tokens wasn't provided (None or missing), compute a dynamic default
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
token_ids = request.get("token_ids", [])
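
For orientation, a minimal sketch (not part of the PR) of how the new `output_options` plumbing behaves. The import path and the empty `default_sampling_params` are assumptions; `logprobs=3` asks vLLM for the top-3 candidates per generated token.

```python
from dynamo.vllm.handlers import build_sampling_params  # path assumed from this PR

request = {
    "token_ids": [1, 2, 3],
    "sampling_options": {"temperature": 0.0},
    "stop_conditions": {"max_tokens": 30},
    "output_options": {"logprobs": 3},  # prompt_logprobs omitted, so left untouched
}

params = build_sampling_params(request, default_sampling_params={})
assert params.logprobs == 3            # top-3 logprobs attached to each new token
assert params.prompt_logprobs is None  # None/missing values are skipped, not zeroed
```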
@@ -563,6 +577,71 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
),
}

@staticmethod
def _extract_logprobs(
output, num_output_tokens_so_far: int
) -> tuple[list[float] | None, list[list[dict]] | None]:
"""
Extract logprobs from vLLM CompletionOutput for new tokens.

Args:
output: vLLM CompletionOutput object
num_output_tokens_so_far: Number of tokens already processed

Returns:
Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
- log_probs: List of log probabilities for each new token
- top_logprobs: List of top logprobs dicts for each new token
"""
if output.logprobs is None:
return None, None

# Get logprobs for new tokens only
new_logprobs = output.logprobs[num_output_tokens_so_far:]
if not new_logprobs:
return None, None

log_probs = []
top_logprobs = []

for token_idx, token_logprobs_dict in enumerate(new_logprobs):
if token_logprobs_dict is None:
continue

# Get the actual token_id that was generated at this position
actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]

# Extract log probability for the selected token
if actual_token_id in token_logprobs_dict:
selected_logprob = token_logprobs_dict[actual_token_id]
log_probs.append(float(selected_logprob.logprob))
else:
# Fallback: use the first logprob if selected token not found
first_logprob = next(iter(token_logprobs_dict.values()), None)
                if first_logprob is not None:
log_probs.append(float(first_logprob.logprob))

# Build top_logprobs list for this token position
token_top_logprobs = []
for tok_id, logprob_info in token_logprobs_dict.items():
token_top_logprobs.append(
{
"rank": logprob_info.rank
if hasattr(logprob_info, "rank")
else 0,
"token_id": tok_id,
"token": (
logprob_info.decoded_token
if hasattr(logprob_info, "decoded_token")
else None
),
"logprob": float(logprob_info.logprob),
}
)
top_logprobs.append(token_top_logprobs)

return log_probs if log_probs else None, top_logprobs if top_logprobs else None
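
For reference, a shape sketch (hypothetical values) of what this helper returns for two new tokens with top-2 candidates each, in Dynamo's expected format:

```python
# log_probs: one float per new token, i.e. the logprob of the token actually chosen.
log_probs = [-0.11, -1.87]

# top_logprobs: per new token, a list of candidate dicts as built above.
top_logprobs = [
    [
        {"rank": 1, "token_id": 9707, "token": "Hello", "logprob": -0.11},
        {"rank": 2, "token_id": 13048, "token": "Hi", "logprob": -2.35},
    ],
    [
        {"rank": 1, "token_id": 11, "token": ",", "logprob": -1.87},
        {"rank": 2, "token_id": 0, "token": "!", "logprob": -1.92},
    ],
]
```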

async def generate_tokens(
self,
prompt,
@@ -608,6 +687,16 @@ async def generate_tokens(
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}

# Extract logprobs for new tokens if available
log_probs, top_logprobs = self._extract_logprobs(
output, num_output_tokens_so_far
)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs

if output.finish_reason:
out["finish_reason"] = output.finish_reason
out[
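Putting it together, a streamed chunk from `generate_tokens` with logprobs enabled would look roughly like this (hypothetical values; `finish_reason` appears only on the final chunk):

```python
out = {
    "token_ids": [11],
    "log_probs": [-1.87],
    "top_logprobs": [
        [{"rank": 1, "token_id": 11, "token": ",", "logprob": -1.87}],
    ],
    # "finish_reason": "stop",  # final chunk only
}
```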
25 changes: 25 additions & 0 deletions tests/serve/test_vllm.py
@@ -19,7 +19,9 @@
from tests.utils.payload_builder import (
chat_payload,
chat_payload_default,
chat_payload_with_logprobs,
completion_payload_default,
completion_payload_with_logprobs,
metric_payload_default,
)

@@ -52,6 +54,29 @@ class VLLMConfig(EngineConfig):
metric_payload_default(min_num_requests=6, backend="vllm"),
],
),
"aggregated_logprobs": VLLMConfig(
name="aggregated_logprobs",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
top_logprobs=3,
),
completion_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
logprobs=5,
),
],
),
"aggregated_lmcache": VLLMConfig(
name="aggregated_lmcache",
directory=vllm_dir,
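As a manual check of what this config exercises end to end, a request like the following should surface logprobs through the OpenAI-compatible frontend. The URL, port, and prompt text are assumptions, not taken from the PR:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # endpoint/port assumed
    json={
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": "Tell me a knock-knock joke about AI"}],
        "max_tokens": 30,
        "temperature": 0.0,
        "logprobs": True,
        "top_logprobs": 3,
    },
    timeout=60,
)
resp.raise_for_status()
# Each entry carries the chosen token, its logprob, and the top-3 alternatives.
print(resp.json()["choices"][0]["logprobs"]["content"][0])
```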
105 changes: 99 additions & 6 deletions tests/utils/payload_builder.py
@@ -6,7 +6,9 @@
from tests.utils.client import send_request
from tests.utils.payloads import (
ChatPayload,
ChatPayloadWithLogprobs,
CompletionPayload,
CompletionPayloadWithLogprobs,
EmbeddingPayload,
MetricsPayload,
)
@@ -134,6 +136,8 @@ def chat_payload(
max_tokens: int = 300,
temperature: Optional[float] = None,
stream: bool = False,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
) -> ChatPayload:
body: Dict[str, Any] = {
"messages": [
@@ -147,6 +151,10 @@
}
if temperature is not None:
body["temperature"] = temperature
if logprobs is not None:
body["logprobs"] = logprobs
if top_logprobs is not None:
body["top_logprobs"] = top_logprobs

return ChatPayload(
body=body,
Expand All @@ -164,14 +172,19 @@ def completion_payload(
max_tokens: int = 150,
temperature: float = 0.1,
stream: bool = False,
logprobs: Optional[int] = None,
) -> CompletionPayload:
body: Dict[str, Any] = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": stream,
}
if logprobs is not None:
body["logprobs"] = logprobs

return CompletionPayload(
-        body={
-            "prompt": prompt,
-            "max_tokens": max_tokens,
-            "temperature": temperature,
-            "stream": stream,
-        },
+        body=body,
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response or [],
@@ -276,3 +289,83 @@ def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool:
return False

return _check_completions_endpoint


def chat_payload_with_logprobs(
content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT,
repeat_count: int = 1,
expected_response: Optional[List[str]] = None,
max_tokens: int = 50,
temperature: float = 0.0,
top_logprobs: int = 3,
) -> ChatPayloadWithLogprobs:
"""
Create a chat payload that requests and validates logprobs in the response.

Args:
content: Message content (text or structured content list)
repeat_count: Number of times to repeat the request
expected_response: List of strings expected in the response text
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_logprobs: Number of top logprobs to return per token

Returns:
ChatPayloadWithLogprobs that validates logprobs in response
"""
body: Dict[str, Any] = {
"messages": [
{
"role": "user",
"content": content,
}
],
"max_tokens": max_tokens,
"temperature": temperature,
"logprobs": True,
"top_logprobs": top_logprobs,
}

return ChatPayloadWithLogprobs(
body=body,
repeat_count=repeat_count,
expected_log=[],
expected_response=expected_response or ["AI", "knock", "joke"],
)


def completion_payload_with_logprobs(
prompt: str = TEXT_PROMPT,
repeat_count: int = 1,
expected_response: Optional[List[str]] = None,
max_tokens: int = 50,
temperature: float = 0.0,
logprobs: int = 5,
) -> CompletionPayloadWithLogprobs:
"""
Create a completion payload that requests and validates logprobs in the response.

Args:
prompt: Text prompt
repeat_count: Number of times to repeat the request
expected_response: List of strings expected in the response text
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
logprobs: Number of logprobs to return per token

Returns:
CompletionPayloadWithLogprobs that validates logprobs in response
"""
body: Dict[str, Any] = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"logprobs": logprobs,
}

return CompletionPayloadWithLogprobs(
body=body,
repeat_count=repeat_count,
expected_log=[],
expected_response=expected_response or ["AI", "knock", "joke"],
)
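
A quick usage sketch of the two builders, mirroring the `aggregated_logprobs` test config:

```python
chat = chat_payload_with_logprobs(max_tokens=30, temperature=0.0, top_logprobs=3)
comp = completion_payload_with_logprobs(max_tokens=30, temperature=0.0, logprobs=5)

# Chat requests follow the OpenAI chat convention (boolean flag plus a
# top_logprobs count); legacy completions take a single integer count.
assert chat.body["logprobs"] is True and chat.body["top_logprobs"] == 3
assert comp.body["logprobs"] == 5
```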
66 changes: 66 additions & 0 deletions tests/utils/payloads.py
@@ -155,6 +155,39 @@ def response_handler(self, response: Any) -> str:
return ChatPayload.extract_content(response)


@dataclass
class ChatPayloadWithLogprobs(ChatPayload):
"""Chat payload that validates logprobs in response."""

def validate(self, response: Any, content: str) -> None:
"""Validate response contains logprobs fields."""
super().validate(response, content)

result = response.json()
choice = result["choices"][0]

# Validate logprobs field exists
assert "logprobs" in choice, "Missing 'logprobs' in choice"

logprobs_data = choice["logprobs"]
if logprobs_data is not None:
assert "content" in logprobs_data, "Missing 'content' in logprobs"
content_logprobs = logprobs_data["content"]

if content_logprobs:
# Validate structure of logprobs
for item in content_logprobs:
assert "token" in item, "Missing 'token' in logprobs content"
assert "logprob" in item, "Missing 'logprob' in logprobs content"
assert (
"top_logprobs" in item
), "Missing 'top_logprobs' in logprobs content"

logger.info(
f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
)
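
For reference, the OpenAI-style chat response shape this validator walks (hypothetical values):

```python
choice = {
    "message": {"role": "assistant", "content": "Why did the AI cross the road?"},
    "logprobs": {
        "content": [
            {
                "token": "Why",
                "logprob": -0.31,
                "top_logprobs": [
                    {"token": "Why", "logprob": -0.31},
                    {"token": "What", "logprob": -1.62},
                ],
            },
        ],
    },
}
```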


@dataclass
class CompletionPayload(BasePayload):
"""Payload for completions endpoint."""
@@ -177,6 +210,39 @@ def response_handler(self, response: Any) -> str:
return CompletionPayload.extract_text(response)


@dataclass
class CompletionPayloadWithLogprobs(CompletionPayload):
"""Completion payload that validates logprobs in response."""

def validate(self, response: Any, content: str) -> None:
"""Validate response contains logprobs fields."""
super().validate(response, content)

result = response.json()
choice = result["choices"][0]

# Validate logprobs field exists
assert "logprobs" in choice, "Missing 'logprobs' in choice"

logprobs_data = choice["logprobs"]
if logprobs_data is not None:
assert (
"token_logprobs" in logprobs_data
), "Missing 'token_logprobs' in logprobs"
assert "tokens" in logprobs_data, "Missing 'tokens' in logprobs"

token_logprobs = logprobs_data["token_logprobs"]
tokens = logprobs_data["tokens"]

if token_logprobs:
assert len(token_logprobs) == len(
tokens
), "Mismatch between token_logprobs and tokens length"
logger.info(
f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
)
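
And the legacy completions shape it checks, with `tokens` and `token_logprobs` as parallel lists (hypothetical values; only the asserted keys are required by the validator):

```python
choice = {
    "text": "Why did the AI cross the road?",
    "logprobs": {
        "tokens": ["Why", " did"],
        "token_logprobs": [-0.31, -0.54],
        "top_logprobs": [{"Why": -0.31}, {" did": -0.54}],
        "text_offset": [0, 3],
    },
}
```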


@dataclass
class EmbeddingPayload(BasePayload):
"""Payload for embeddings endpoint."""