From 85ff86c52544a66f46df7d82047b426b3b1b132d Mon Sep 17 00:00:00 2001 From: Aryan Bagade Date: Thu, 11 Dec 2025 15:21:06 -0800 Subject: [PATCH 1/4] feat: Add logprobs support to SGLang backend closes: #4685 Signed-off-by: Aryan Bagade --- .../request_handlers/llm/decode_handler.py | 114 ++++++++++++++++++ tests/serve/test_sglang.py | 25 ++++ 2 files changed, 139 insertions(+) diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 47572e2f54..6824e7bd95 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -14,6 +14,8 @@ from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler +logger = logging.getLogger(__name__) + class DecodeWorkerHandler(BaseWorkerHandler): """Handler for decode workers in both aggregated and disaggregated serving modes.""" @@ -77,6 +79,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: # Token-based request format sampling_opts = request.get("sampling_options", {}) stop_conditions = request.get("stop_conditions", {}) + output_options = request.get("output_options", {}) param_mapping = { "temperature": sampling_opts.get("temperature"), @@ -85,6 +88,23 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: "max_new_tokens": stop_conditions.get("max_tokens"), "ignore_eos": stop_conditions.get("ignore_eos"), } + + # Handle logprobs from output_options + logprobs_value = output_options.get("logprobs") + if logprobs_value is not None and logprobs_value != "": + try: + parsed_logprobs = int(logprobs_value) + if parsed_logprobs < 0: + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring" + ) + else: + param_mapping["return_logprob"] = True + param_mapping["top_logprobs_num"] = parsed_logprobs + except (ValueError, TypeError): + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring" + ) else: # OpenAI request format param_mapping = { @@ -94,6 +114,14 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: "max_new_tokens": request.get("max_tokens"), } + # Handle logprobs from OpenAI format + logprobs = request.get("logprobs") + top_logprobs = request.get("top_logprobs") + if logprobs: + param_mapping["return_logprob"] = True + if top_logprobs is not None: + param_mapping["top_logprobs_num"] = top_logprobs + return {k: v for k, v in param_mapping.items() if v is not None} async def generate( @@ -193,6 +221,82 @@ async def generate( async for out in self._process_text_stream(agg, context): yield out + @staticmethod + def _extract_logprobs( + res: Dict[str, Any], num_output_tokens_so_far: int + ) -> tuple[list[float] | None, list[list[dict]] | None]: + """ + Extract logprobs from SGLang response for new tokens. + + Args: + res: SGLang response dict + num_output_tokens_so_far: Number of tokens already processed + + Returns: + Tuple of (log_probs, top_logprobs) in Dynamo's expected format: + - log_probs: List of log probabilities for each new token + - top_logprobs: List of top logprobs dicts for each new token + """ + meta_info = res.get("meta_info", {}) + + # SGLang uses "output_token_logprobs" for selected token logprobs + # Format: [(logprob, token_id, decoded_text), ...] - one tuple per token + output_token_logprobs = meta_info.get("output_token_logprobs") + + # SGLang uses "output_top_logprobs" for top-k alternatives + # Format: [[(logprob, token_id, text), ...], ...] - list of lists + output_top_logprobs = meta_info.get("output_top_logprobs") + + if not output_token_logprobs: + return None, None + + # Get logprobs for new tokens only + new_token_logprobs = output_token_logprobs[num_output_tokens_so_far:] + if not new_token_logprobs: + return None, None + + log_probs = [] + top_logprobs = [] + + # Extract selected token logprobs + for token_data in new_token_logprobs: + if token_data is None: + continue + # SGLang format: (logprob, token_id, decoded_text) + logprob_val = token_data[0] + if logprob_val is not None: + log_probs.append(float(logprob_val)) + + # Extract top logprobs if available + if output_top_logprobs: + new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:] + for token_top_list in new_top_logprobs: + if not token_top_list: + top_logprobs.append([]) + continue + + token_top_logprobs = [] + for rank, alt_data in enumerate(token_top_list): + if alt_data is None: + continue + # SGLang format: (logprob, token_id, decoded_text) + logprob_val = alt_data[0] + token_id = alt_data[1] + decoded_text = alt_data[2] if len(alt_data) > 2 else None + token_top_logprobs.append( + { + "rank": rank, + "token_id": token_id, + "token": decoded_text, + "logprob": ( + float(logprob_val) if logprob_val is not None else None + ), + } + ) + top_logprobs.append(token_top_logprobs) + + return log_probs if log_probs else None, top_logprobs if top_logprobs else None + async def _process_token_stream( self, stream_source: AsyncGenerator[Dict[str, Any], None], @@ -239,6 +343,16 @@ async def _process_token_stream( next_total_toks = len(output_ids) out["token_ids"] = output_ids[num_output_tokens_so_far:] + + # Extract logprobs for new tokens + log_probs, top_logprobs = self._extract_logprobs( + res, num_output_tokens_so_far + ) + if log_probs is not None: + out["log_probs"] = log_probs + if top_logprobs is not None: + out["top_logprobs"] = top_logprobs + num_output_tokens_so_far = next_total_toks if finish_reason: input_tokens = res["meta_info"]["prompt_tokens"] diff --git a/tests/serve/test_sglang.py b/tests/serve/test_sglang.py index 9591d13571..68ca384703 100644 --- a/tests/serve/test_sglang.py +++ b/tests/serve/test_sglang.py @@ -17,7 +17,9 @@ from tests.utils.payload_builder import ( chat_payload, chat_payload_default, + chat_payload_with_logprobs, completion_payload_default, + completion_payload_with_logprobs, embedding_payload, embedding_payload_default, metric_payload_default, @@ -229,6 +231,29 @@ class SGLangConfig(EngineConfig): completion_payload_default(), ], ), + "aggregated_logprobs": SGLangConfig( + name="aggregated_logprobs", + directory=sglang_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1], + model="Qwen/Qwen3-0.6B", + request_payloads=[ + chat_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + top_logprobs=3, + ), + completion_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + logprobs=5, + ), + ], + ), } From 246ab544a1f9bf6573d826265d1153f556b9e641 Mon Sep 17 00:00:00 2001 From: Aryan Bagade Date: Thu, 11 Dec 2025 16:14:31 -0800 Subject: [PATCH 2/4] fix: address CodeRabbit review feedback for SGLang logprobs Signed-off-by: Aryan Bagade --- .../request_handlers/llm/decode_handler.py | 70 +++++++++++-------- tests/serve/test_sglang.py | 6 +- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 6824e7bd95..dfc42adcf3 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -258,42 +258,54 @@ def _extract_logprobs( log_probs = [] top_logprobs = [] - # Extract selected token logprobs - for token_data in new_token_logprobs: + # Get top logprobs slice if available + new_top_logprobs = None + if output_top_logprobs: + new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:] + + # Extract logprobs for each token, maintaining 1:1 alignment + for idx, token_data in enumerate(new_token_logprobs): + # Skip if token_data is None or logprob_val is None if token_data is None: continue # SGLang format: (logprob, token_id, decoded_text) logprob_val = token_data[0] - if logprob_val is not None: - log_probs.append(float(logprob_val)) + if logprob_val is None: + continue - # Extract top logprobs if available - if output_top_logprobs: - new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:] - for token_top_list in new_top_logprobs: + log_probs.append(float(logprob_val)) + + # Extract corresponding top logprobs for this token position + if new_top_logprobs and idx < len(new_top_logprobs): + token_top_list = new_top_logprobs[idx] if not token_top_list: top_logprobs.append([]) - continue - - token_top_logprobs = [] - for rank, alt_data in enumerate(token_top_list): - if alt_data is None: - continue - # SGLang format: (logprob, token_id, decoded_text) - logprob_val = alt_data[0] - token_id = alt_data[1] - decoded_text = alt_data[2] if len(alt_data) > 2 else None - token_top_logprobs.append( - { - "rank": rank, - "token_id": token_id, - "token": decoded_text, - "logprob": ( - float(logprob_val) if logprob_val is not None else None - ), - } - ) - top_logprobs.append(token_top_logprobs) + else: + # Filter out None entries and sort by logprob descending + # SGLang doesn't guarantee order, so we sort to assign proper ranks + valid_entries = [ + alt_data + for alt_data in token_top_list + if alt_data is not None and alt_data[0] is not None + ] + # Sort by logprob descending (highest probability first) + valid_entries.sort(key=lambda x: x[0], reverse=True) + + token_top_logprobs = [] + for rank, alt_data in enumerate(valid_entries): + # SGLang format: (logprob, token_id, decoded_text) + alt_logprob_val = alt_data[0] + token_id = alt_data[1] + decoded_text = alt_data[2] if len(alt_data) > 2 else None + token_top_logprobs.append( + { + "rank": rank, + "token_id": token_id, + "token": decoded_text, + "logprob": float(alt_logprob_val), + } + ) + top_logprobs.append(token_top_logprobs) return log_probs if log_probs else None, top_logprobs if top_logprobs else None diff --git a/tests/serve/test_sglang.py b/tests/serve/test_sglang.py index 68ca384703..ec882e79a3 100644 --- a/tests/serve/test_sglang.py +++ b/tests/serve/test_sglang.py @@ -235,7 +235,11 @@ class SGLangConfig(EngineConfig): name="aggregated_logprobs", directory=sglang_dir, script_name="agg.sh", - marks=[pytest.mark.gpu_1], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.timeout(240), # 3x measured time + download time + ], model="Qwen/Qwen3-0.6B", request_payloads=[ chat_payload_with_logprobs( From 0e60cfdd20c3b9202e9a3d10ec92184883e5cb63 Mon Sep 17 00:00:00 2001 From: Aryan Bagade Date: Thu, 11 Dec 2025 19:31:27 -0800 Subject: [PATCH 3/4] fix: use logging module directly per review feedback Signed-off-by: Aryan Bagade --- .../dynamo/sglang/request_handlers/llm/decode_handler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index dfc42adcf3..9db7729bb8 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -14,8 +14,6 @@ from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler -logger = logging.getLogger(__name__) - class DecodeWorkerHandler(BaseWorkerHandler): """Handler for decode workers in both aggregated and disaggregated serving modes.""" @@ -95,14 +93,14 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: try: parsed_logprobs = int(logprobs_value) if parsed_logprobs < 0: - logger.warning( + logging.warning( f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring" ) else: param_mapping["return_logprob"] = True param_mapping["top_logprobs_num"] = parsed_logprobs except (ValueError, TypeError): - logger.warning( + logging.warning( f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring" ) else: From 2420fc97bc72cbe6ca6722654a8d12f09c6aff4a Mon Sep 17 00:00:00 2001 From: Aryan Bagade Date: Fri, 12 Dec 2025 19:13:31 -0800 Subject: [PATCH 4/4] fix: pass logprobs params to async_generate not SamplingParams SGLang's SamplingParams doesn't accept return_logprob or top_logprobs_num. These must be passed as separate kwargs to engine.async_generate(). Signed-off-by: Aryan Bagade --- .../dynamo/sglang/request_handlers/llm/decode_handler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 9db7729bb8..cce666a733 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -142,6 +142,10 @@ async def generate( sampling_params = self._build_sampling_params(request) input_param = self._get_input_param(request) + # Extract logprobs params (they go to async_generate, not SamplingParams) + return_logprob = sampling_params.pop("return_logprob", False) + top_logprobs_num = sampling_params.pop("top_logprobs_num", None) + if self.serving_mode == DisaggregationMode.DECODE: # request the bootstrap info from the target prefill worker if ( @@ -190,6 +194,8 @@ async def generate( **input_param, sampling_params=sampling_params, stream=True, + return_logprob=return_logprob, + top_logprobs_num=top_logprobs_num, bootstrap_host=bootstrap_info["bootstrap_host"], bootstrap_port=bootstrap_info["bootstrap_port"], bootstrap_room=bootstrap_info["bootstrap_room"], @@ -210,6 +216,8 @@ async def generate( **input_param, sampling_params=sampling_params, stream=True, + return_logprob=return_logprob, + top_logprobs_num=top_logprobs_num, rid=trace_id, ) if self.skip_tokenizer_init: