
Commit 85ff86c

feat: Add logprobs support to SGLang backend

closes: #4685
Signed-off-by: Aryan Bagade <[email protected]>
1 parent: 82577b0

File tree

2 files changed: +139 -0 lines changed


components/src/dynamo/sglang/request_handlers/llm/decode_handler.py

Lines changed: 114 additions & 0 deletions
@@ -14,6 +14,8 @@
 from dynamo.sglang.publisher import DynamoSglangPublisher
 from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
 
+logger = logging.getLogger(__name__)
+
 
 class DecodeWorkerHandler(BaseWorkerHandler):
     """Handler for decode workers in both aggregated and disaggregated serving modes."""
@@ -77,6 +79,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
             # Token-based request format
             sampling_opts = request.get("sampling_options", {})
             stop_conditions = request.get("stop_conditions", {})
+            output_options = request.get("output_options", {})
 
             param_mapping = {
                 "temperature": sampling_opts.get("temperature"),
@@ -85,6 +88,23 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
                 "max_new_tokens": stop_conditions.get("max_tokens"),
                 "ignore_eos": stop_conditions.get("ignore_eos"),
             }
+
+            # Handle logprobs from output_options
+            logprobs_value = output_options.get("logprobs")
+            if logprobs_value is not None and logprobs_value != "":
+                try:
+                    parsed_logprobs = int(logprobs_value)
+                    if parsed_logprobs < 0:
+                        logger.warning(
+                            f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring"
+                        )
+                    else:
+                        param_mapping["return_logprob"] = True
+                        param_mapping["top_logprobs_num"] = parsed_logprobs
+                except (ValueError, TypeError):
+                    logger.warning(
+                        f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring"
+                    )
         else:
             # OpenAI request format
             param_mapping = {
@@ -94,6 +114,14 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
                 "max_new_tokens": request.get("max_tokens"),
             }
 
+            # Handle logprobs from OpenAI format
+            logprobs = request.get("logprobs")
+            top_logprobs = request.get("top_logprobs")
+            if logprobs:
+                param_mapping["return_logprob"] = True
+                if top_logprobs is not None:
+                    param_mapping["top_logprobs_num"] = top_logprobs
+
         return {k: v for k, v in param_mapping.items() if v is not None}
 
     async def generate(
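
Note: with the two hunks above, _build_sampling_params accepts logprobs in either request shape. A minimal sketch of the mapping (input values illustrative; field names are taken from the diff itself):

    # Token-based (Dynamo) format: the logprobs count arrives via output_options.
    token_request = {
        "sampling_options": {"temperature": 0.0},
        "stop_conditions": {"max_tokens": 30},
        "output_options": {"logprobs": 3},  # -> return_logprob=True, top_logprobs_num=3
    }

    # OpenAI format: logprobs is a bool flag, top_logprobs an int.
    openai_request = {
        "temperature": 0.0,
        "max_tokens": 30,
        "logprobs": True,   # -> return_logprob=True
        "top_logprobs": 3,  # -> top_logprobs_num=3
    }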
@@ -193,6 +221,82 @@ async def generate(
         async for out in self._process_text_stream(agg, context):
             yield out
 
+    @staticmethod
+    def _extract_logprobs(
+        res: Dict[str, Any], num_output_tokens_so_far: int
+    ) -> tuple[list[float] | None, list[list[dict]] | None]:
+        """
+        Extract logprobs from SGLang response for new tokens.
+
+        Args:
+            res: SGLang response dict
+            num_output_tokens_so_far: Number of tokens already processed
+
+        Returns:
+            Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
+            - log_probs: List of log probabilities for each new token
+            - top_logprobs: List of top logprobs dicts for each new token
+        """
+        meta_info = res.get("meta_info", {})
+
+        # SGLang uses "output_token_logprobs" for selected token logprobs
+        # Format: [(logprob, token_id, decoded_text), ...] - one tuple per token
+        output_token_logprobs = meta_info.get("output_token_logprobs")
+
+        # SGLang uses "output_top_logprobs" for top-k alternatives
+        # Format: [[(logprob, token_id, text), ...], ...] - list of lists
+        output_top_logprobs = meta_info.get("output_top_logprobs")
+
+        if not output_token_logprobs:
+            return None, None
+
+        # Get logprobs for new tokens only
+        new_token_logprobs = output_token_logprobs[num_output_tokens_so_far:]
+        if not new_token_logprobs:
+            return None, None
+
+        log_probs = []
+        top_logprobs = []
+
+        # Extract selected token logprobs
+        for token_data in new_token_logprobs:
+            if token_data is None:
+                continue
+            # SGLang format: (logprob, token_id, decoded_text)
+            logprob_val = token_data[0]
+            if logprob_val is not None:
+                log_probs.append(float(logprob_val))
+
+        # Extract top logprobs if available
+        if output_top_logprobs:
+            new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:]
+            for token_top_list in new_top_logprobs:
+                if not token_top_list:
+                    top_logprobs.append([])
+                    continue
+
+                token_top_logprobs = []
+                for rank, alt_data in enumerate(token_top_list):
+                    if alt_data is None:
+                        continue
+                    # SGLang format: (logprob, token_id, decoded_text)
+                    logprob_val = alt_data[0]
+                    token_id = alt_data[1]
+                    decoded_text = alt_data[2] if len(alt_data) > 2 else None
+                    token_top_logprobs.append(
+                        {
+                            "rank": rank,
+                            "token_id": token_id,
+                            "token": decoded_text,
+                            "logprob": (
+                                float(logprob_val) if logprob_val is not None else None
+                            ),
+                        }
+                    )
+                top_logprobs.append(token_top_logprobs)
+
+        return log_probs if log_probs else None, top_logprobs if top_logprobs else None
+
     async def _process_token_stream(
         self,
         stream_source: AsyncGenerator[Dict[str, Any], None],
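
Note: _extract_logprobs returns only the delta since the last chunk, which is what lets _process_token_stream attach per-chunk logprobs in the next hunk. A worked example with invented values (the tuple layout matches the comments in the method):

    res = {
        "meta_info": {
            "output_token_logprobs": [(-0.1, 11, "Hi"), (-0.5, 42, " there"), (-0.2, 7, "!")],
            "output_top_logprobs": [
                [(-0.1, 11, "Hi"), (-1.3, 12, "Hey")],
                [(-0.5, 42, " there"), (-0.9, 43, " you")],
                [(-0.2, 7, "!"), (-2.0, 8, ".")],
            ],
        }
    }
    log_probs, top_logprobs = DecodeWorkerHandler._extract_logprobs(res, num_output_tokens_so_far=1)
    # log_probs    -> [-0.5, -0.2]  (only the two new tokens)
    # top_logprobs -> [[{"rank": 0, "token_id": 42, "token": " there", "logprob": -0.5},
    #                   {"rank": 1, "token_id": 43, "token": " you", "logprob": -0.9}],
    #                  [{"rank": 0, "token_id": 7, "token": "!", "logprob": -0.2},
    #                   {"rank": 1, "token_id": 8, "token": ".", "logprob": -2.0}]]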
@@ -239,6 +343,16 @@ async def _process_token_stream(
 
             next_total_toks = len(output_ids)
             out["token_ids"] = output_ids[num_output_tokens_so_far:]
+
+            # Extract logprobs for new tokens
+            log_probs, top_logprobs = self._extract_logprobs(
+                res, num_output_tokens_so_far
+            )
+            if log_probs is not None:
+                out["log_probs"] = log_probs
+            if top_logprobs is not None:
+                out["top_logprobs"] = top_logprobs
+
             num_output_tokens_so_far = next_total_toks
             if finish_reason:
                 input_tokens = res["meta_info"]["prompt_tokens"]
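
Note: after this hunk, each streamed chunk can carry log_probs and top_logprobs alongside token_ids. A hypothetical consumer loop (the driver code and handler variable are illustrative, not part of this diff):

    async for out in handler.generate(request, context):
        if "log_probs" in out:
            for tid, lp in zip(out["token_ids"], out["log_probs"]):
                print(f"token {tid}: logprob {lp:.3f}")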

tests/serve/test_sglang.py

Lines changed: 25 additions & 0 deletions
@@ -17,7 +17,9 @@
 from tests.utils.payload_builder import (
     chat_payload,
     chat_payload_default,
+    chat_payload_with_logprobs,
     completion_payload_default,
+    completion_payload_with_logprobs,
     embedding_payload,
     embedding_payload_default,
     metric_payload_default,
@@ -229,6 +231,29 @@ class SGLangConfig(EngineConfig):
             completion_payload_default(),
         ],
     ),
+    "aggregated_logprobs": SGLangConfig(
+        name="aggregated_logprobs",
+        directory=sglang_dir,
+        script_name="agg.sh",
+        marks=[pytest.mark.gpu_1],
+        model="Qwen/Qwen3-0.6B",
+        request_payloads=[
+            chat_payload_with_logprobs(
+                repeat_count=2,
+                expected_response=["AI", "knock", "joke"],
+                max_tokens=30,
+                temperature=0.0,
+                top_logprobs=3,
+            ),
+            completion_payload_with_logprobs(
+                repeat_count=2,
+                expected_response=["AI", "knock", "joke"],
+                max_tokens=30,
+                temperature=0.0,
+                logprobs=5,
+            ),
+        ],
+    ),
 }
 
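
Note: the payload builders chat_payload_with_logprobs / completion_payload_with_logprobs live in tests/utils/payload_builder.py, which is not part of this diff. Assuming they emit standard OpenAI-style request bodies, the two new payloads correspond roughly to the following (prompt text assumed from the expected joke keywords):

    chat_body = {
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": "Tell me a joke"}],  # assumed prompt
        "max_tokens": 30,
        "temperature": 0.0,
        "logprobs": True,
        "top_logprobs": 3,
    }
    completion_body = {
        "model": "Qwen/Qwen3-0.6B",
        "prompt": "Tell me a joke",  # assumed prompt
        "max_tokens": 30,
        "temperature": 0.0,
        "logprobs": 5,  # the completions API takes an int here
    }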
