-
Notifications
You must be signed in to change notification settings - Fork 744
feat: Add logprobs support to SGLang backend #4912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
85ff86c
32486d5
246ab54
1cf26f9
0e60cfd
45fb04e
2420fc9
b1da340
80c7d8c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,8 @@ | |
| from dynamo.sglang.publisher import DynamoSglangPublisher | ||
| from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class DecodeWorkerHandler(BaseWorkerHandler): | ||
| """Handler for decode workers in both aggregated and disaggregated serving modes.""" | ||
|
|
@@ -77,6 +79,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: | |
| # Token-based request format | ||
| sampling_opts = request.get("sampling_options", {}) | ||
| stop_conditions = request.get("stop_conditions", {}) | ||
| output_options = request.get("output_options", {}) | ||
|
|
||
| param_mapping = { | ||
| "temperature": sampling_opts.get("temperature"), | ||
|
|
@@ -85,6 +88,23 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: | |
| "max_new_tokens": stop_conditions.get("max_tokens"), | ||
| "ignore_eos": stop_conditions.get("ignore_eos"), | ||
| } | ||
|
|
||
| # Handle logprobs from output_options | ||
| logprobs_value = output_options.get("logprobs") | ||
| if logprobs_value is not None and logprobs_value != "": | ||
| try: | ||
| parsed_logprobs = int(logprobs_value) | ||
| if parsed_logprobs < 0: | ||
| logger.warning( | ||
| f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring" | ||
| ) | ||
| else: | ||
| param_mapping["return_logprob"] = True | ||
| param_mapping["top_logprobs_num"] = parsed_logprobs | ||
| except (ValueError, TypeError): | ||
| logger.warning( | ||
| f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring" | ||
| ) | ||
| else: | ||
| # OpenAI request format | ||
| param_mapping = { | ||
|
|
@@ -94,6 +114,14 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]: | |
| "max_new_tokens": request.get("max_tokens"), | ||
| } | ||
|
|
||
| # Handle logprobs from OpenAI format | ||
| logprobs = request.get("logprobs") | ||
| top_logprobs = request.get("top_logprobs") | ||
| if logprobs: | ||
| param_mapping["return_logprob"] = True | ||
| if top_logprobs is not None: | ||
| param_mapping["top_logprobs_num"] = top_logprobs | ||
|
|
||
| return {k: v for k, v in param_mapping.items() if v is not None} | ||
|
|
||
| async def generate( | ||
|
|
@@ -193,6 +221,94 @@ async def generate( | |
| async for out in self._process_text_stream(agg, context): | ||
| yield out | ||
|
|
||
| @staticmethod | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. wdyt about moving this into handler base?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I did some digging here; basically, the extraction logic is backend-specific, since SGLang returns logprobs as tuples (logprob, token_id, decoded_text), which differs from TRTLLM's structure. |
||
| def _extract_logprobs( | ||
| res: Dict[str, Any], num_output_tokens_so_far: int | ||
| ) -> tuple[list[float] | None, list[list[dict]] | None]: | ||
| """ | ||
| Extract logprobs from SGLang response for new tokens. | ||
|
|
||
| Args: | ||
| res: SGLang response dict | ||
| num_output_tokens_so_far: Number of tokens already processed | ||
|
|
||
| Returns: | ||
| Tuple of (log_probs, top_logprobs) in Dynamo's expected format: | ||
| - log_probs: List of log probabilities for each new token | ||
| - top_logprobs: List of top logprobs dicts for each new token | ||
| """ | ||
| meta_info = res.get("meta_info", {}) | ||
|
|
||
| # SGLang uses "output_token_logprobs" for selected token logprobs | ||
| # Format: [(logprob, token_id, decoded_text), ...] - one tuple per token | ||
| output_token_logprobs = meta_info.get("output_token_logprobs") | ||
|
|
||
| # SGLang uses "output_top_logprobs" for top-k alternatives | ||
| # Format: [[(logprob, token_id, text), ...], ...] - list of lists | ||
| output_top_logprobs = meta_info.get("output_top_logprobs") | ||
|
|
||
| if not output_token_logprobs: | ||
| return None, None | ||
|
|
||
| # Get logprobs for new tokens only | ||
| new_token_logprobs = output_token_logprobs[num_output_tokens_so_far:] | ||
| if not new_token_logprobs: | ||
| return None, None | ||
|
|
||
| log_probs = [] | ||
| top_logprobs = [] | ||
|
|
||
| # Get top logprobs slice if available | ||
| new_top_logprobs = None | ||
| if output_top_logprobs: | ||
| new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:] | ||
|
|
||
| # Extract logprobs for each token, maintaining 1:1 alignment | ||
| for idx, token_data in enumerate(new_token_logprobs): | ||
| # Skip if token_data is None or logprob_val is None | ||
| if token_data is None: | ||
| continue | ||
| # SGLang format: (logprob, token_id, decoded_text) | ||
| logprob_val = token_data[0] | ||
| if logprob_val is None: | ||
| continue | ||
|
|
||
| log_probs.append(float(logprob_val)) | ||
|
|
||
| # Extract corresponding top logprobs for this token position | ||
| if new_top_logprobs and idx < len(new_top_logprobs): | ||
| token_top_list = new_top_logprobs[idx] | ||
| if not token_top_list: | ||
| top_logprobs.append([]) | ||
| else: | ||
| # Filter out None entries and sort by logprob descending | ||
| # SGLang doesn't guarantee order, so we sort to assign proper ranks | ||
| valid_entries = [ | ||
| alt_data | ||
| for alt_data in token_top_list | ||
| if alt_data is not None and alt_data[0] is not None | ||
| ] | ||
| # Sort by logprob descending (highest probability first) | ||
| valid_entries.sort(key=lambda x: x[0], reverse=True) | ||
|
|
||
| token_top_logprobs = [] | ||
| for rank, alt_data in enumerate(valid_entries): | ||
| # SGLang format: (logprob, token_id, decoded_text) | ||
| alt_logprob_val = alt_data[0] | ||
| token_id = alt_data[1] | ||
| decoded_text = alt_data[2] if len(alt_data) > 2 else None | ||
| token_top_logprobs.append( | ||
| { | ||
| "rank": rank, | ||
| "token_id": token_id, | ||
| "token": decoded_text, | ||
| "logprob": float(alt_logprob_val), | ||
| } | ||
| ) | ||
| top_logprobs.append(token_top_logprobs) | ||
|
|
||
| return log_probs if log_probs else None, top_logprobs if top_logprobs else None | ||
|
|
||
| async def _process_token_stream( | ||
| self, | ||
| stream_source: AsyncGenerator[Dict[str, Any], None], | ||
|
|
@@ -239,6 +355,16 @@ async def _process_token_stream( | |
|
|
||
| next_total_toks = len(output_ids) | ||
| out["token_ids"] = output_ids[num_output_tokens_so_far:] | ||
|
|
||
| # Extract logprobs for new tokens | ||
| log_probs, top_logprobs = self._extract_logprobs( | ||
| res, num_output_tokens_so_far | ||
| ) | ||
| if log_probs is not None: | ||
| out["log_probs"] = log_probs | ||
| if top_logprobs is not None: | ||
| out["top_logprobs"] = top_logprobs | ||
|
|
||
| num_output_tokens_so_far = next_total_toks | ||
| if finish_reason: | ||
| input_tokens = res["meta_info"]["prompt_tokens"] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.