@@ -258,42 +258,54 @@ def _extract_logprobs(
258258 log_probs = []
259259 top_logprobs = []
260260
261- # Extract selected token logprobs
262- for token_data in new_token_logprobs :
261+ # Get top logprobs slice if available
262+ new_top_logprobs = None
263+ if output_top_logprobs :
264+ new_top_logprobs = output_top_logprobs [num_output_tokens_so_far :]
265+
266+ # Extract logprobs for each token, maintaining 1:1 alignment
267+ for idx , token_data in enumerate (new_token_logprobs ):
268+ # Skip if token_data is None or logprob_val is None
263269 if token_data is None :
264270 continue
265271 # SGLang format: (logprob, token_id, decoded_text)
266272 logprob_val = token_data [0 ]
267- if logprob_val is not None :
268- log_probs . append ( float ( logprob_val ))
273+ if logprob_val is None :
274+ continue
269275
270- # Extract top logprobs if available
271- if output_top_logprobs :
272- new_top_logprobs = output_top_logprobs [num_output_tokens_so_far :]
273- for token_top_list in new_top_logprobs :
276+ log_probs .append (float (logprob_val ))
277+
278+ # Extract corresponding top logprobs for this token position
279+ if new_top_logprobs and idx < len (new_top_logprobs ):
280+ token_top_list = new_top_logprobs [idx ]
274281 if not token_top_list :
275282 top_logprobs .append ([])
276- continue
277-
278- token_top_logprobs = []
279- for rank , alt_data in enumerate (token_top_list ):
280- if alt_data is None :
281- continue
282- # SGLang format: (logprob, token_id, decoded_text)
283- logprob_val = alt_data [0 ]
284- token_id = alt_data [1 ]
285- decoded_text = alt_data [2 ] if len (alt_data ) > 2 else None
286- token_top_logprobs .append (
287- {
288- "rank" : rank ,
289- "token_id" : token_id ,
290- "token" : decoded_text ,
291- "logprob" : (
292- float (logprob_val ) if logprob_val is not None else None
293- ),
294- }
295- )
296- top_logprobs .append (token_top_logprobs )
283+ else :
284+ # Filter out None entries and sort by logprob descending
285+ # SGLang doesn't guarantee order, so we sort to assign proper ranks
286+ valid_entries = [
287+ alt_data
288+ for alt_data in token_top_list
289+ if alt_data is not None and alt_data [0 ] is not None
290+ ]
291+ # Sort by logprob descending (highest probability first)
292+ valid_entries .sort (key = lambda x : x [0 ], reverse = True )
293+
294+ token_top_logprobs = []
295+ for rank , alt_data in enumerate (valid_entries ):
296+ # SGLang format: (logprob, token_id, decoded_text)
297+ alt_logprob_val = alt_data [0 ]
298+ token_id = alt_data [1 ]
299+ decoded_text = alt_data [2 ] if len (alt_data ) > 2 else None
300+ token_top_logprobs .append (
301+ {
302+ "rank" : rank ,
303+ "token_id" : token_id ,
304+ "token" : decoded_text ,
305+ "logprob" : float (alt_logprob_val ),
306+ }
307+ )
308+ top_logprobs .append (token_top_logprobs )
297309
298310 return log_probs if log_probs else None , top_logprobs if top_logprobs else None
299311
0 commit comments