From 3566e823b97e6475a010ab1495ee6961937595fa Mon Sep 17 00:00:00 2001 From: Aryan Bagade Date: Tue, 2 Dec 2025 03:12:48 -0800 Subject: [PATCH 1/3] feat: Add logprobs support to vLLM backend (#4683) Implement logprobs functionality for vLLM backend to pass log probability information from vLLM to the frontend. Signed-off-by: Aryan Bagade --- components/src/dynamo/vllm/handlers.py | 91 ++++++++++++++++++++- tests/serve/test_vllm.py | 25 ++++++ tests/utils/payload_builder.py | 105 +++++++++++++++++++++++-- tests/utils/payloads.py | 66 ++++++++++++++++ 4 files changed, 280 insertions(+), 7 deletions(-) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 6851aa6ad9..4d671ab053 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -73,7 +73,8 @@ def build_sampling_params( Build SamplingParams from a PreprocessedRequest. Args: - request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions' + request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions', + and 'output_options' default_sampling_params: Default sampling parameters to initialize with Returns: @@ -95,6 +96,19 @@ def build_sampling_params( continue setattr(sampling_params, key, value) + # Apply output_options (logprobs, prompt_logprobs, etc.) + output_options = request.get("output_options", {}) + if output_options: + # Handle logprobs - vLLM expects this as an integer or None + logprobs_value = output_options.get("logprobs") + if logprobs_value is not None: + sampling_params.logprobs = int(logprobs_value) + + # Handle prompt_logprobs - vLLM expects this as an integer or None + prompt_logprobs_value = output_options.get("prompt_logprobs") + if prompt_logprobs_value is not None: + sampling_params.prompt_logprobs = int(prompt_logprobs_value) + # If max_tokens wasn't provided (None or missing), compute a dynamic default provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None) token_ids = request.get("token_ids", []) @@ -556,6 +570,71 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]: ), } + @staticmethod + def _extract_logprobs( + output, num_output_tokens_so_far: int + ) -> tuple[list[float] | None, list[list[dict]] | None]: + """ + Extract logprobs from vLLM CompletionOutput for new tokens. 
+ + Args: + output: vLLM CompletionOutput object + num_output_tokens_so_far: Number of tokens already processed + + Returns: + Tuple of (log_probs, top_logprobs) in Dynamo's expected format: + - log_probs: List of log probabilities for each new token + - top_logprobs: List of top logprobs dicts for each new token + """ + if output.logprobs is None: + return None, None + + # Get logprobs for new tokens only + new_logprobs = output.logprobs[num_output_tokens_so_far:] + if not new_logprobs: + return None, None + + log_probs = [] + top_logprobs = [] + + for token_idx, token_logprobs_dict in enumerate(new_logprobs): + if token_logprobs_dict is None: + continue + + # Get the actual token_id that was generated at this position + actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx] + + # Extract log probability for the selected token + if actual_token_id in token_logprobs_dict: + selected_logprob = token_logprobs_dict[actual_token_id] + log_probs.append(float(selected_logprob.logprob)) + else: + # Fallback: use the first logprob if selected token not found + first_logprob = next(iter(token_logprobs_dict.values()), None) + if first_logprob: + log_probs.append(float(first_logprob.logprob)) + + # Build top_logprobs list for this token position + token_top_logprobs = [] + for tok_id, logprob_info in token_logprobs_dict.items(): + token_top_logprobs.append( + { + "rank": logprob_info.rank + if hasattr(logprob_info, "rank") + else 0, + "token_id": tok_id, + "token": ( + logprob_info.decoded_token + if hasattr(logprob_info, "decoded_token") + else None + ), + "logprob": float(logprob_info.logprob), + } + ) + top_logprobs.append(token_top_logprobs) + + return log_probs if log_probs else None, top_logprobs if top_logprobs else None + async def generate_tokens( self, prompt, @@ -601,6 +680,16 @@ async def generate_tokens( output = res.outputs[0] next_total_toks = len(output.token_ids) out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} + + # Extract logprobs for new tokens if available + log_probs, top_logprobs = self._extract_logprobs( + output, num_output_tokens_so_far + ) + if log_probs is not None: + out["log_probs"] = log_probs + if top_logprobs is not None: + out["top_logprobs"] = top_logprobs + if output.finish_reason: out["finish_reason"] = output.finish_reason out[ diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py index c45d283d02..2dbf34dc11 100644 --- a/tests/serve/test_vllm.py +++ b/tests/serve/test_vllm.py @@ -18,7 +18,9 @@ from tests.utils.payload_builder import ( chat_payload, chat_payload_default, + chat_payload_with_logprobs, completion_payload_default, + completion_payload_with_logprobs, metric_payload_default, ) @@ -51,6 +53,29 @@ class VLLMConfig(EngineConfig): metric_payload_default(min_num_requests=6, backend="vllm"), ], ), + "aggregated_logprobs": VLLMConfig( + name="aggregated_logprobs", + directory=vllm_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1], + model="Qwen/Qwen3-0.6B", + request_payloads=[ + chat_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + top_logprobs=3, + ), + completion_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + logprobs=5, + ), + ], + ), "aggregated_lmcache": VLLMConfig( name="aggregated_lmcache", directory=vllm_dir, diff --git a/tests/utils/payload_builder.py b/tests/utils/payload_builder.py index 1b2e8bf963..df5bef4226 100644 --- 
a/tests/utils/payload_builder.py +++ b/tests/utils/payload_builder.py @@ -6,7 +6,9 @@ from tests.utils.client import send_request from tests.utils.payloads import ( ChatPayload, + ChatPayloadWithLogprobs, CompletionPayload, + CompletionPayloadWithLogprobs, EmbeddingPayload, MetricsPayload, ) @@ -134,6 +136,8 @@ def chat_payload( max_tokens: int = 300, temperature: Optional[float] = None, stream: bool = False, + logprobs: Optional[int] = None, + top_logprobs: Optional[int] = None, ) -> ChatPayload: body: Dict[str, Any] = { "messages": [ @@ -147,6 +151,10 @@ def chat_payload( } if temperature is not None: body["temperature"] = temperature + if logprobs is not None: + body["logprobs"] = logprobs + if top_logprobs is not None: + body["top_logprobs"] = top_logprobs return ChatPayload( body=body, @@ -164,14 +172,19 @@ def completion_payload( max_tokens: int = 150, temperature: float = 0.1, stream: bool = False, + logprobs: Optional[int] = None, ) -> CompletionPayload: + body: Dict[str, Any] = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": temperature, + "stream": stream, + } + if logprobs is not None: + body["logprobs"] = logprobs + return CompletionPayload( - body={ - "prompt": prompt, - "max_tokens": max_tokens, - "temperature": temperature, - "stream": stream, - }, + body=body, repeat_count=repeat_count, expected_log=expected_log or [], expected_response=expected_response or [], @@ -276,3 +289,83 @@ def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool: return False return _check_completions_endpoint + + +def chat_payload_with_logprobs( + content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT, + repeat_count: int = 1, + expected_response: Optional[List[str]] = None, + max_tokens: int = 50, + temperature: float = 0.0, + top_logprobs: int = 3, +) -> ChatPayloadWithLogprobs: + """ + Create a chat payload that requests and validates logprobs in the response. + + Args: + content: Message content (text or structured content list) + repeat_count: Number of times to repeat the request + expected_response: List of strings expected in the response text + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_logprobs: Number of top logprobs to return per token + + Returns: + ChatPayloadWithLogprobs that validates logprobs in response + """ + body: Dict[str, Any] = { + "messages": [ + { + "role": "user", + "content": content, + } + ], + "max_tokens": max_tokens, + "temperature": temperature, + "logprobs": True, + "top_logprobs": top_logprobs, + } + + return ChatPayloadWithLogprobs( + body=body, + repeat_count=repeat_count, + expected_log=[], + expected_response=expected_response or ["AI", "knock", "joke"], + ) + + +def completion_payload_with_logprobs( + prompt: str = TEXT_PROMPT, + repeat_count: int = 1, + expected_response: Optional[List[str]] = None, + max_tokens: int = 50, + temperature: float = 0.0, + logprobs: int = 5, +) -> CompletionPayloadWithLogprobs: + """ + Create a completion payload that requests and validates logprobs in the response. 
+ + Args: + prompt: Text prompt + repeat_count: Number of times to repeat the request + expected_response: List of strings expected in the response text + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + logprobs: Number of logprobs to return per token + + Returns: + CompletionPayloadWithLogprobs that validates logprobs in response + """ + body: Dict[str, Any] = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": temperature, + "logprobs": logprobs, + } + + return CompletionPayloadWithLogprobs( + body=body, + repeat_count=repeat_count, + expected_log=[], + expected_response=expected_response or ["AI", "knock", "joke"], + ) diff --git a/tests/utils/payloads.py b/tests/utils/payloads.py index 3a18dfdf44..bae4ae778e 100644 --- a/tests/utils/payloads.py +++ b/tests/utils/payloads.py @@ -155,6 +155,39 @@ def response_handler(self, response: Any) -> str: return ChatPayload.extract_content(response) +@dataclass +class ChatPayloadWithLogprobs(ChatPayload): + """Chat payload that validates logprobs in response.""" + + def validate(self, response: Any, content: str) -> None: + """Validate response contains logprobs fields.""" + super().validate(response, content) + + result = response.json() + choice = result["choices"][0] + + # Validate logprobs field exists + assert "logprobs" in choice, "Missing 'logprobs' in choice" + + logprobs_data = choice["logprobs"] + if logprobs_data is not None: + assert "content" in logprobs_data, "Missing 'content' in logprobs" + content_logprobs = logprobs_data["content"] + + if content_logprobs: + # Validate structure of logprobs + for item in content_logprobs: + assert "token" in item, "Missing 'token' in logprobs content" + assert "logprob" in item, "Missing 'logprob' in logprobs content" + assert ( + "top_logprobs" in item + ), "Missing 'top_logprobs' in logprobs content" + + logger.info( + f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs" + ) + + @dataclass class CompletionPayload(BasePayload): """Payload for completions endpoint.""" @@ -177,6 +210,39 @@ def response_handler(self, response: Any) -> str: return CompletionPayload.extract_text(response) +@dataclass +class CompletionPayloadWithLogprobs(CompletionPayload): + """Completion payload that validates logprobs in response.""" + + def validate(self, response: Any, content: str) -> None: + """Validate response contains logprobs fields.""" + super().validate(response, content) + + result = response.json() + choice = result["choices"][0] + + # Validate logprobs field exists + assert "logprobs" in choice, "Missing 'logprobs' in choice" + + logprobs_data = choice["logprobs"] + if logprobs_data is not None: + assert ( + "token_logprobs" in logprobs_data + ), "Missing 'token_logprobs' in logprobs" + assert "tokens" in logprobs_data, "Missing 'tokens' in logprobs" + + token_logprobs = logprobs_data["token_logprobs"] + tokens = logprobs_data["tokens"] + + if token_logprobs: + assert len(token_logprobs) == len( + tokens + ), "Mismatch between token_logprobs and tokens length" + logger.info( + f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs" + ) + + @dataclass class EmbeddingPayload(BasePayload): """Payload for embeddings endpoint.""" From 6f7d85ed61037e6ffa005d9dddff1d11ab9ee600 Mon Sep 17 00:00:00 2001 From: Aryan Bagade Date: Tue, 9 Dec 2025 01:56:57 -0800 Subject: [PATCH 2/3] fix: address review feedback for logprobs implementation Signed-off-by: Aryan Bagade --- components/src/dynamo/vllm/handlers.py | 17 
++++++----------- tests/utils/payloads.py | 26 ++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index af5b95f365..9a5d74577e 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -626,23 +626,18 @@ def _extract_logprobs( actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx] # Extract log probability for the selected token - if actual_token_id in token_logprobs_dict: - selected_logprob = token_logprobs_dict[actual_token_id] - log_probs.append(float(selected_logprob.logprob)) - else: - # Fallback: use the first logprob if selected token not found - first_logprob = next(iter(token_logprobs_dict.values()), None) - if first_logprob: - log_probs.append(float(first_logprob.logprob)) + # vLLM guarantees the selected token is always in the logprobs dict + selected_logprob = token_logprobs_dict[actual_token_id] + log_probs.append(float(selected_logprob.logprob)) # Build top_logprobs list for this token position token_top_logprobs = [] for tok_id, logprob_info in token_logprobs_dict.items(): token_top_logprobs.append( { - "rank": logprob_info.rank - if hasattr(logprob_info, "rank") - else 0, + "rank": ( + logprob_info.rank if hasattr(logprob_info, "rank") else 0 + ), "token_id": tok_id, "token": ( logprob_info.decoded_token diff --git a/tests/utils/payloads.py b/tests/utils/payloads.py index 33127fd9b1..f8ad8162fb 100644 --- a/tests/utils/payloads.py +++ b/tests/utils/payloads.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import math import re import time from copy import deepcopy @@ -183,9 +184,20 @@ def validate(self, response: Any, content: str) -> None: "top_logprobs" in item ), "Missing 'top_logprobs' in logprobs content" + # Sanity check: logprob should be valid (not nan/inf/positive) + logprob_val = item["logprob"] + assert not math.isnan(logprob_val), "logprob is NaN" + assert not math.isinf(logprob_val), "logprob is infinite" + assert ( + logprob_val <= 0 + ), f"logprob should be <= 0, got {logprob_val}" + logger.info( f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs" ) + + +@dataclass class ToolCallingChatPayload(ChatPayload): """ChatPayload that validates tool calls in the response.""" @@ -278,6 +290,20 @@ def validate(self, response: Any, content: str) -> None: assert len(token_logprobs) == len( tokens ), "Mismatch between token_logprobs and tokens length" + + # Sanity check: each logprob should be valid (not nan/inf/positive) + for i, logprob_val in enumerate(token_logprobs): + if logprob_val is not None: # First token can be None + assert not math.isnan( + logprob_val + ), f"logprob at index {i} is NaN" + assert not math.isinf( + logprob_val + ), f"logprob at index {i} is infinite" + assert ( + logprob_val <= 0 + ), f"logprob at index {i} should be <= 0, got {logprob_val}" + logger.info( f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs" ) From bbb8fb939ef1a9b45bf8a06d176356c31e3b0f07 Mon Sep 17 00:00:00 2001 From: Aryan Bagade Date: Tue, 9 Dec 2025 02:12:40 -0800 Subject: [PATCH 3/3] fix: add input validation for logprobs values Signed-off-by: Aryan Bagade --- components/src/dynamo/vllm/handlers.py | 30 ++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 9a5d74577e..f1bc298755 100644 --- 
a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -122,13 +122,35 @@ def build_sampling_params( if output_options: # Handle logprobs - vLLM expects this as an integer or None logprobs_value = output_options.get("logprobs") - if logprobs_value is not None: - sampling_params.logprobs = int(logprobs_value) + if logprobs_value is not None and logprobs_value != "": + try: + parsed_logprobs = int(logprobs_value) + if parsed_logprobs < 0: + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring" + ) + else: + sampling_params.logprobs = parsed_logprobs + except (ValueError, TypeError): + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring" + ) # Handle prompt_logprobs - vLLM expects this as an integer or None prompt_logprobs_value = output_options.get("prompt_logprobs") - if prompt_logprobs_value is not None: - sampling_params.prompt_logprobs = int(prompt_logprobs_value) + if prompt_logprobs_value is not None and prompt_logprobs_value != "": + try: + parsed_prompt_logprobs = int(prompt_logprobs_value) + if parsed_prompt_logprobs < 0: + logger.warning( + f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be non-negative), ignoring" + ) + else: + sampling_params.prompt_logprobs = parsed_prompt_logprobs + except (ValueError, TypeError): + logger.warning( + f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be integer), ignoring" + ) # If max_tokens wasn't provided (None or missing), compute a dynamic default provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
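
Reviewer note (not part of the patch): a minimal client-side sketch of the logprobs flow these changes enable, mirroring what chat_payload_with_logprobs sends and what ChatPayloadWithLogprobs asserts. The frontend URL/port and the exact prompt text are assumptions; the model name and the response field names come from the tests in this series.

    # Sketch only: exercise the OpenAI-compatible chat endpoint with logprobs enabled.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed frontend address; adjust to your deployment
        json={
            "model": "Qwen/Qwen3-0.6B",  # model used by the aggregated_logprobs test config
            "messages": [{"role": "user", "content": "Tell me a short joke about AI"}],  # stand-in prompt
            "max_tokens": 30,
            "temperature": 0.0,
            "logprobs": True,     # chat API takes a boolean flag ...
            "top_logprobs": 3,    # ... plus the number of alternatives per sampled token
        },
        timeout=60,
    )
    resp.raise_for_status()
    choice = resp.json()["choices"][0]

    # Shape asserted by ChatPayloadWithLogprobs: each entry carries the sampled token,
    # a finite logprob <= 0, and a top_logprobs list. Internally, _extract_logprobs in
    # handlers.py surfaces the same data as per-delta "log_probs" / "top_logprobs"
    # fields alongside "token_ids".
    for item in choice["logprobs"]["content"]:
        print(item["token"], item["logprob"], len(item["top_logprobs"]))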