diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 32fed9afc0..f1bc298755 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -73,7 +73,8 @@ def build_sampling_params( Build SamplingParams from a PreprocessedRequest. Args: - request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions' + request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions', + and 'output_options' default_sampling_params: Default sampling parameters to initialize with Returns: @@ -116,6 +117,41 @@ def build_sampling_params( existing = sampling_params.stop_token_ids or [] sampling_params.stop_token_ids = list(set(existing).union(value)) + # Apply output_options (logprobs, prompt_logprobs, etc.) + output_options = request.get("output_options", {}) + if output_options: + # Handle logprobs - vLLM expects this as an integer or None + logprobs_value = output_options.get("logprobs") + if logprobs_value is not None and logprobs_value != "": + try: + parsed_logprobs = int(logprobs_value) + if parsed_logprobs < 0: + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring" + ) + else: + sampling_params.logprobs = parsed_logprobs + except (ValueError, TypeError): + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring" + ) + + # Handle prompt_logprobs - vLLM expects this as an integer or None + prompt_logprobs_value = output_options.get("prompt_logprobs") + if prompt_logprobs_value is not None and prompt_logprobs_value != "": + try: + parsed_prompt_logprobs = int(prompt_logprobs_value) + if parsed_prompt_logprobs < 0: + logger.warning( + f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be non-negative), ignoring" + ) + else: + sampling_params.prompt_logprobs = parsed_prompt_logprobs + except (ValueError, TypeError): + logger.warning( + f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be integer), ignoring" + ) + # If max_tokens wasn't provided (None or missing), compute a dynamic default provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None) token_ids = request.get("token_ids", []) @@ -577,6 +613,66 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]: ), } + @staticmethod + def _extract_logprobs( + output, num_output_tokens_so_far: int + ) -> tuple[list[float] | None, list[list[dict]] | None]: + """ + Extract logprobs from vLLM CompletionOutput for new tokens. 
+ + Args: + output: vLLM CompletionOutput object + num_output_tokens_so_far: Number of tokens already processed + + Returns: + Tuple of (log_probs, top_logprobs) in Dynamo's expected format: + - log_probs: List of log probabilities for each new token + - top_logprobs: List of top logprobs dicts for each new token + """ + if output.logprobs is None: + return None, None + + # Get logprobs for new tokens only + new_logprobs = output.logprobs[num_output_tokens_so_far:] + if not new_logprobs: + return None, None + + log_probs = [] + top_logprobs = [] + + for token_idx, token_logprobs_dict in enumerate(new_logprobs): + if token_logprobs_dict is None: + continue + + # Get the actual token_id that was generated at this position + actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx] + + # Extract log probability for the selected token + # vLLM guarantees the selected token is always in the logprobs dict + selected_logprob = token_logprobs_dict[actual_token_id] + log_probs.append(float(selected_logprob.logprob)) + + # Build top_logprobs list for this token position + token_top_logprobs = [] + for tok_id, logprob_info in token_logprobs_dict.items(): + token_top_logprobs.append( + { + "rank": ( + logprob_info.rank if hasattr(logprob_info, "rank") else 0 + ), + "token_id": tok_id, + "token": ( + logprob_info.decoded_token + if hasattr(logprob_info, "decoded_token") + else None + ), + "logprob": float(logprob_info.logprob), + } + ) + top_logprobs.append(token_top_logprobs) + + return log_probs if log_probs else None, top_logprobs if top_logprobs else None + async def generate_tokens( self, prompt, @@ -622,6 +718,16 @@ async def generate_tokens( output = res.outputs[0] next_total_toks = len(output.token_ids) out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} + + # Extract logprobs for new tokens if available + log_probs, top_logprobs = self._extract_logprobs( + output, num_output_tokens_so_far + ) + if log_probs is not None: + out["log_probs"] = log_probs + if top_logprobs is not None: + out["top_logprobs"] = top_logprobs + if output.finish_reason: out["finish_reason"] = output.finish_reason out[ diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py index 55873d07d0..c1f8887963 100644 --- a/tests/serve/test_vllm.py +++ b/tests/serve/test_vllm.py @@ -19,7 +19,9 @@ from tests.utils.payload_builder import ( chat_payload, chat_payload_default, + chat_payload_with_logprobs, completion_payload_default, + completion_payload_with_logprobs, metric_payload_default, ) from tests.utils.payloads import ToolCallingChatPayload @@ -59,6 +61,29 @@ class VLLMConfig(EngineConfig): metric_payload_default(min_num_requests=6, backend="vllm"), ], ), + "aggregated_logprobs": VLLMConfig( + name="aggregated_logprobs", + directory=vllm_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1], + model="Qwen/Qwen3-0.6B", + request_payloads=[ + chat_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + top_logprobs=3, + ), + completion_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + logprobs=5, + ), + ], + ), "aggregated_lmcache": VLLMConfig( name="aggregated_lmcache", directory=vllm_dir, diff --git a/tests/utils/payload_builder.py b/tests/utils/payload_builder.py index d1e7549bae..017a35a461 100644 --- a/tests/utils/payload_builder.py +++ b/tests/utils/payload_builder.py @@ -153,6 +153,10 @@ def chat_payload( } if temperature is not None: 
         body["temperature"] = temperature
+    if logprobs is not None:
+        body["logprobs"] = logprobs
+    if top_logprobs is not None:
+        body["top_logprobs"] = top_logprobs
@@ -307,3 +311,83 @@ def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool:
         return False
 
     return _check_completions_endpoint
+
+
+def chat_payload_with_logprobs(
+    content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT,
+    repeat_count: int = 1,
+    expected_response: Optional[List[str]] = None,
+    max_tokens: int = 50,
+    temperature: float = 0.0,
+    top_logprobs: int = 3,
+) -> ChatPayloadWithLogprobs:
+    """
+    Create a chat payload that requests and validates logprobs in the response.
+
+    Args:
+        content: Message content (text or structured content list)
+        repeat_count: Number of times to repeat the request
+        expected_response: List of strings expected in the response text
+        max_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        top_logprobs: Number of top logprobs to return per token
+
+    Returns:
+        ChatPayloadWithLogprobs that validates logprobs in response
+    """
+    body: Dict[str, Any] = {
+        "messages": [
+            {
+                "role": "user",
+                "content": content,
+            }
+        ],
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "logprobs": True,
+        "top_logprobs": top_logprobs,
+    }
+
+    return ChatPayloadWithLogprobs(
+        body=body,
+        repeat_count=repeat_count,
+        expected_log=[],
+        expected_response=expected_response or ["AI", "knock", "joke"],
+    )
+
+
+def completion_payload_with_logprobs(
+    prompt: str = TEXT_PROMPT,
+    repeat_count: int = 1,
+    expected_response: Optional[List[str]] = None,
+    max_tokens: int = 50,
+    temperature: float = 0.0,
+    logprobs: int = 5,
+) -> CompletionPayloadWithLogprobs:
+    """
+    Create a completion payload that requests and validates logprobs in the response.
+
+    Args:
+        prompt: Text prompt
+        repeat_count: Number of times to repeat the request
+        expected_response: List of strings expected in the response text
+        max_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        logprobs: Number of logprobs to return per token
+
+    Returns:
+        CompletionPayloadWithLogprobs that validates logprobs in response
+    """
+    body: Dict[str, Any] = {
+        "prompt": prompt,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "logprobs": logprobs,
+    }
+
+    return CompletionPayloadWithLogprobs(
+        body=body,
+        repeat_count=repeat_count,
+        expected_log=[],
+        expected_response=expected_response or ["AI", "knock", "joke"],
+    )
diff --git a/tests/utils/payloads.py b/tests/utils/payloads.py
index 917ad36c0b..f8ad8162fb 100644
--- a/tests/utils/payloads.py
+++ b/tests/utils/payloads.py
@@ -14,6 +14,7 @@
 # limitations under the License.
import logging +import math import re import time from copy import deepcopy @@ -183,6 +184,14 @@ def validate(self, response: Any, content: str) -> None: "top_logprobs" in item ), "Missing 'top_logprobs' in logprobs content" + # Sanity check: logprob should be valid (not nan/inf/positive) + logprob_val = item["logprob"] + assert not math.isnan(logprob_val), "logprob is NaN" + assert not math.isinf(logprob_val), "logprob is infinite" + assert ( + logprob_val <= 0 + ), f"logprob should be <= 0, got {logprob_val}" + logger.info( f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs" ) @@ -281,6 +290,20 @@ def validate(self, response: Any, content: str) -> None: assert len(token_logprobs) == len( tokens ), "Mismatch between token_logprobs and tokens length" + + # Sanity check: each logprob should be valid (not nan/inf/positive) + for i, logprob_val in enumerate(token_logprobs): + if logprob_val is not None: # First token can be None + assert not math.isnan( + logprob_val + ), f"logprob at index {i} is NaN" + assert not math.isinf( + logprob_val + ), f"logprob at index {i} is infinite" + assert ( + logprob_val <= 0 + ), f"logprob at index {i} should be <= 0, got {logprob_val}" + logger.info( f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs" )
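
A minimal end-to-end sketch of how the new logprobs plumbing can be exercised against the aggregated deployment, mirroring the request bodies built by chat_payload_with_logprobs and the checks performed by the validators in tests/utils/payloads.py. The base URL/port and the prompt text are illustrative assumptions (the tests run against the deployment launched by agg.sh and use the shared TEXT_PROMPT); the model name matches the "aggregated_logprobs" config.

import requests  # assumption: available in the client environment

# Assumption: the Dynamo frontend is serving the OpenAI-compatible API on
# localhost:8000 with the "aggregated_logprobs" deployment (Qwen/Qwen3-0.6B).
BASE_URL = "http://localhost:8000"

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-0.6B",
        # Illustrative prompt; the tests use the shared TEXT_PROMPT constant.
        "messages": [{"role": "user", "content": "Tell me a joke about AI"}],
        "max_tokens": 30,
        "temperature": 0.0,
        "logprobs": True,
        "top_logprobs": 3,
    },
    timeout=120,
)
resp.raise_for_status()

# Same invariants the chat validator asserts: each content entry carries a
# finite, non-positive logprob for the chosen token plus its top_logprobs list.
for item in resp.json()["choices"][0]["logprobs"]["content"]:
    assert item["logprob"] <= 0.0
    print(item["token"], item["logprob"], len(item["top_logprobs"]))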