Skip to content

Commit e6b1c9d

Browse files
committed
fix: address CodeRabbit review feedback for SGLang logprobs
1 parent 32486d5 commit e6b1c9d

File tree

2 files changed

+46
-30
lines changed

2 files changed

+46
-30
lines changed

components/src/dynamo/sglang/request_handlers/llm/decode_handler.py

Lines changed: 41 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -258,42 +258,54 @@ def _extract_logprobs(
258258
log_probs = []
259259
top_logprobs = []
260260

261-
# Extract selected token logprobs
262-
for token_data in new_token_logprobs:
261+
# Get top logprobs slice if available
262+
new_top_logprobs = None
263+
if output_top_logprobs:
264+
new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:]
265+
266+
# Extract logprobs for each token, maintaining 1:1 alignment
267+
for idx, token_data in enumerate(new_token_logprobs):
268+
# Skip if token_data is None or logprob_val is None
263269
if token_data is None:
264270
continue
265271
# SGLang format: (logprob, token_id, decoded_text)
266272
logprob_val = token_data[0]
267-
if logprob_val is not None:
268-
log_probs.append(float(logprob_val))
273+
if logprob_val is None:
274+
continue
269275

270-
# Extract top logprobs if available
271-
if output_top_logprobs:
272-
new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:]
273-
for token_top_list in new_top_logprobs:
276+
log_probs.append(float(logprob_val))
277+
278+
# Extract corresponding top logprobs for this token position
279+
if new_top_logprobs and idx < len(new_top_logprobs):
280+
token_top_list = new_top_logprobs[idx]
274281
if not token_top_list:
275282
top_logprobs.append([])
276-
continue
277-
278-
token_top_logprobs = []
279-
for rank, alt_data in enumerate(token_top_list):
280-
if alt_data is None:
281-
continue
282-
# SGLang format: (logprob, token_id, decoded_text)
283-
logprob_val = alt_data[0]
284-
token_id = alt_data[1]
285-
decoded_text = alt_data[2] if len(alt_data) > 2 else None
286-
token_top_logprobs.append(
287-
{
288-
"rank": rank,
289-
"token_id": token_id,
290-
"token": decoded_text,
291-
"logprob": (
292-
float(logprob_val) if logprob_val is not None else None
293-
),
294-
}
295-
)
296-
top_logprobs.append(token_top_logprobs)
283+
else:
284+
# Filter out None entries and sort by logprob descending
285+
# SGLang doesn't guarantee order, so we sort to assign proper ranks
286+
valid_entries = [
287+
alt_data
288+
for alt_data in token_top_list
289+
if alt_data is not None and alt_data[0] is not None
290+
]
291+
# Sort by logprob descending (highest probability first)
292+
valid_entries.sort(key=lambda x: x[0], reverse=True)
293+
294+
token_top_logprobs = []
295+
for rank, alt_data in enumerate(valid_entries):
296+
# SGLang format: (logprob, token_id, decoded_text)
297+
alt_logprob_val = alt_data[0]
298+
token_id = alt_data[1]
299+
decoded_text = alt_data[2] if len(alt_data) > 2 else None
300+
token_top_logprobs.append(
301+
{
302+
"rank": rank,
303+
"token_id": token_id,
304+
"token": decoded_text,
305+
"logprob": float(alt_logprob_val),
306+
}
307+
)
308+
top_logprobs.append(token_top_logprobs)
297309

298310
return log_probs if log_probs else None, top_logprobs if top_logprobs else None
299311

tests/serve/test_sglang.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -235,7 +235,11 @@ class SGLangConfig(EngineConfig):
235235
name="aggregated_logprobs",
236236
directory=sglang_dir,
237237
script_name="agg.sh",
238-
marks=[pytest.mark.gpu_1],
238+
marks=[
239+
pytest.mark.gpu_1,
240+
pytest.mark.pre_merge,
241+
pytest.mark.timeout(240), # 3x measured time + download time
242+
],
239243
model="Qwen/Qwen3-0.6B",
240244
request_payloads=[
241245
chat_payload_with_logprobs(

0 commit comments

Comments
 (0)