1414from dynamo .sglang .publisher import DynamoSglangPublisher
1515from dynamo .sglang .request_handlers .handler_base import BaseWorkerHandler
1616
17+ logger = logging .getLogger (__name__ )
18+
1719
1820class DecodeWorkerHandler (BaseWorkerHandler ):
1921 """Handler for decode workers in both aggregated and disaggregated serving modes."""
@@ -77,6 +79,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
7779 # Token-based request format
7880 sampling_opts = request .get ("sampling_options" , {})
7981 stop_conditions = request .get ("stop_conditions" , {})
82+ output_options = request .get ("output_options" , {})
8083
8184 param_mapping = {
8285 "temperature" : sampling_opts .get ("temperature" ),
@@ -85,6 +88,23 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
8588 "max_new_tokens" : stop_conditions .get ("max_tokens" ),
8689 "ignore_eos" : stop_conditions .get ("ignore_eos" ),
8790 }
91+
92+ # Handle logprobs from output_options
93+ logprobs_value = output_options .get ("logprobs" )
94+ if logprobs_value is not None and logprobs_value != "" :
95+ try :
96+ parsed_logprobs = int (logprobs_value )
97+ if parsed_logprobs < 0 :
98+ logger .warning (
99+ f"Invalid logprobs value: { logprobs_value } (must be non-negative), ignoring"
100+ )
101+ else :
102+ param_mapping ["return_logprob" ] = True
103+ param_mapping ["top_logprobs_num" ] = parsed_logprobs
104+ except (ValueError , TypeError ):
105+ logger .warning (
106+ f"Invalid logprobs value: { logprobs_value } (must be integer), ignoring"
107+ )
88108 else :
89109 # OpenAI request format
90110 param_mapping = {
@@ -94,6 +114,14 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
94114 "max_new_tokens" : request .get ("max_tokens" ),
95115 }
96116
117+ # Handle logprobs from OpenAI format
118+ logprobs = request .get ("logprobs" )
119+ top_logprobs = request .get ("top_logprobs" )
120+ if logprobs :
121+ param_mapping ["return_logprob" ] = True
122+ if top_logprobs is not None :
123+ param_mapping ["top_logprobs_num" ] = top_logprobs
124+
97125 return {k : v for k , v in param_mapping .items () if v is not None }
98126
99127 async def generate (
@@ -193,6 +221,82 @@ async def generate(
193221 async for out in self ._process_text_stream (agg , context ):
194222 yield out
195223
224+ @staticmethod
225+ def _extract_logprobs (
226+ res : Dict [str , Any ], num_output_tokens_so_far : int
227+ ) -> tuple [list [float ] | None , list [list [dict ]] | None ]:
228+ """
229+ Extract logprobs from SGLang response for new tokens.
230+
231+ Args:
232+ res: SGLang response dict
233+ num_output_tokens_so_far: Number of tokens already processed
234+
235+ Returns:
236+ Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
237+ - log_probs: List of log probabilities for each new token
238+ - top_logprobs: List of top logprobs dicts for each new token
239+ """
240+ meta_info = res .get ("meta_info" , {})
241+
242+ # SGLang uses "output_token_logprobs" for selected token logprobs
243+ # Format: [(logprob, token_id, decoded_text), ...] - one tuple per token
244+ output_token_logprobs = meta_info .get ("output_token_logprobs" )
245+
246+ # SGLang uses "output_top_logprobs" for top-k alternatives
247+ # Format: [[(logprob, token_id, text), ...], ...] - list of lists
248+ output_top_logprobs = meta_info .get ("output_top_logprobs" )
249+
250+ if not output_token_logprobs :
251+ return None , None
252+
253+ # Get logprobs for new tokens only
254+ new_token_logprobs = output_token_logprobs [num_output_tokens_so_far :]
255+ if not new_token_logprobs :
256+ return None , None
257+
258+ log_probs = []
259+ top_logprobs = []
260+
261+ # Extract selected token logprobs
262+ for token_data in new_token_logprobs :
263+ if token_data is None :
264+ continue
265+ # SGLang format: (logprob, token_id, decoded_text)
266+ logprob_val = token_data [0 ]
267+ if logprob_val is not None :
268+ log_probs .append (float (logprob_val ))
269+
270+ # Extract top logprobs if available
271+ if output_top_logprobs :
272+ new_top_logprobs = output_top_logprobs [num_output_tokens_so_far :]
273+ for token_top_list in new_top_logprobs :
274+ if not token_top_list :
275+ top_logprobs .append ([])
276+ continue
277+
278+ token_top_logprobs = []
279+ for rank , alt_data in enumerate (token_top_list ):
280+ if alt_data is None :
281+ continue
282+ # SGLang format: (logprob, token_id, decoded_text)
283+ logprob_val = alt_data [0 ]
284+ token_id = alt_data [1 ]
285+ decoded_text = alt_data [2 ] if len (alt_data ) > 2 else None
286+ token_top_logprobs .append (
287+ {
288+ "rank" : rank ,
289+ "token_id" : token_id ,
290+ "token" : decoded_text ,
291+ "logprob" : (
292+ float (logprob_val ) if logprob_val is not None else None
293+ ),
294+ }
295+ )
296+ top_logprobs .append (token_top_logprobs )
297+
298+ return log_probs if log_probs else None , top_logprobs if top_logprobs else None
299+
196300 async def _process_token_stream (
197301 self ,
198302 stream_source : AsyncGenerator [Dict [str , Any ], None ],
@@ -239,6 +343,16 @@ async def _process_token_stream(
239343
240344 next_total_toks = len (output_ids )
241345 out ["token_ids" ] = output_ids [num_output_tokens_so_far :]
346+
347+ # Extract logprobs for new tokens
348+ log_probs , top_logprobs = self ._extract_logprobs (
349+ res , num_output_tokens_so_far
350+ )
351+ if log_probs is not None :
352+ out ["log_probs" ] = log_probs
353+ if top_logprobs is not None :
354+ out ["top_logprobs" ] = top_logprobs
355+
242356 num_output_tokens_so_far = next_total_toks
243357 if finish_reason :
244358 input_tokens = res ["meta_info" ]["prompt_tokens" ]
0 commit comments