126 changes: 126 additions & 0 deletions components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
@@ -14,6 +14,8 @@
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler

logger = logging.getLogger(__name__)


class DecodeWorkerHandler(BaseWorkerHandler):
"""Handler for decode workers in both aggregated and disaggregated serving modes."""
@@ -77,6 +79,7 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
# Token-based request format
sampling_opts = request.get("sampling_options", {})
stop_conditions = request.get("stop_conditions", {})
output_options = request.get("output_options", {})

param_mapping = {
"temperature": sampling_opts.get("temperature"),
@@ -85,6 +88,23 @@
"max_new_tokens": stop_conditions.get("max_tokens"),
"ignore_eos": stop_conditions.get("ignore_eos"),
}

# Handle logprobs from output_options
logprobs_value = output_options.get("logprobs")
if logprobs_value is not None and logprobs_value != "":
try:
parsed_logprobs = int(logprobs_value)
if parsed_logprobs < 0:
logger.warning(
f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring"
)
else:
param_mapping["return_logprob"] = True
param_mapping["top_logprobs_num"] = parsed_logprobs
except (ValueError, TypeError):
logger.warning(
f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring"
)
else:
# OpenAI request format
param_mapping = {
@@ -94,6 +114,14 @@ def _build_sampling_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"max_new_tokens": request.get("max_tokens"),
}

# Handle logprobs from OpenAI format
logprobs = request.get("logprobs")
top_logprobs = request.get("top_logprobs")
if logprobs:
param_mapping["return_logprob"] = True
if top_logprobs is not None:
param_mapping["top_logprobs_num"] = top_logprobs

return {k: v for k, v in param_mapping.items() if v is not None}
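
For reference, a minimal sketch of how the two request formats map onto SGLang sampling params, assuming the request dict reaches the matching branch of `_build_sampling_params` above (the payloads are illustrative, not taken from this PR):

```python
# Token-based (Dynamo) request: logprobs arrives under output_options and
# may arrive as a string, hence the defensive int() parse in the handler.
token_request = {
    "sampling_options": {"temperature": 0.0},
    "stop_conditions": {"max_tokens": 30},
    "output_options": {"logprobs": "3"},
}
# -> {"temperature": 0.0, "max_new_tokens": 30,
#     "return_logprob": True, "top_logprobs_num": 3}

# OpenAI-style request: `logprobs` is a boolean flag, `top_logprobs` an int.
openai_request = {"max_tokens": 30, "logprobs": True, "top_logprobs": 3}
# -> {"max_new_tokens": 30, "return_logprob": True, "top_logprobs_num": 3}
```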

async def generate(
@@ -193,6 +221,94 @@ async def generate(
async for out in self._process_text_stream(agg, context):
yield out

@staticmethod
Contributor: wdyt about moving this into handler base?

Contributor Author: I did some digging here; the extraction logic is backend-specific, since SGLang returns logprobs as (logprob, token_id, decoded_text) tuples, which differs from TRTLLM's structure. vLLM also keeps its _extract_logprobs in handlers.py rather than in a shared base for the same reason. Happy to refactor if you see a clean abstraction, though!

def _extract_logprobs(
res: Dict[str, Any], num_output_tokens_so_far: int
) -> tuple[list[float] | None, list[list[dict]] | None]:
"""
Extract logprobs from SGLang response for new tokens.

Args:
res: SGLang response dict
num_output_tokens_so_far: Number of tokens already processed

Returns:
Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
- log_probs: List of log probabilities for each new token
- top_logprobs: List of top logprobs dicts for each new token
"""
meta_info = res.get("meta_info", {})

# SGLang uses "output_token_logprobs" for selected token logprobs
# Format: [(logprob, token_id, decoded_text), ...] - one tuple per token
output_token_logprobs = meta_info.get("output_token_logprobs")

# SGLang uses "output_top_logprobs" for top-k alternatives
# Format: [[(logprob, token_id, text), ...], ...] - list of lists
output_top_logprobs = meta_info.get("output_top_logprobs")

if not output_token_logprobs:
return None, None

# Get logprobs for new tokens only
new_token_logprobs = output_token_logprobs[num_output_tokens_so_far:]
if not new_token_logprobs:
return None, None

log_probs = []
top_logprobs = []

# Get top logprobs slice if available
new_top_logprobs = None
if output_top_logprobs:
new_top_logprobs = output_top_logprobs[num_output_tokens_so_far:]

# Extract logprobs for each token, maintaining 1:1 alignment
for idx, token_data in enumerate(new_token_logprobs):
# Skip if token_data is None or logprob_val is None
if token_data is None:
continue
# SGLang format: (logprob, token_id, decoded_text)
logprob_val = token_data[0]
if logprob_val is None:
continue

log_probs.append(float(logprob_val))

# Extract corresponding top logprobs for this token position
if new_top_logprobs and idx < len(new_top_logprobs):
token_top_list = new_top_logprobs[idx]
if not token_top_list:
top_logprobs.append([])
else:
# Filter out None entries and sort by logprob descending
# SGLang doesn't guarantee order, so we sort to assign proper ranks
valid_entries = [
alt_data
for alt_data in token_top_list
if alt_data is not None and alt_data[0] is not None
]
# Sort by logprob descending (highest probability first)
valid_entries.sort(key=lambda x: x[0], reverse=True)

token_top_logprobs = []
for rank, alt_data in enumerate(valid_entries):
# SGLang format: (logprob, token_id, decoded_text)
alt_logprob_val = alt_data[0]
token_id = alt_data[1]
decoded_text = alt_data[2] if len(alt_data) > 2 else None
token_top_logprobs.append(
{
"rank": rank,
"token_id": token_id,
"token": decoded_text,
"logprob": float(alt_logprob_val),
}
)
top_logprobs.append(token_top_logprobs)

return log_probs if log_probs else None, top_logprobs if top_logprobs else None
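
A worked example of the tuple formats this helper consumes and the Dynamo-shaped output it produces (values are made up):

```python
# Hypothetical SGLang response fragment: one previously-emitted token
# plus one new token.
res = {
    "meta_info": {
        # (logprob, token_id, decoded_text) per generated token
        "output_token_logprobs": [(-0.11, 9906, "Hello"), (-0.42, 1917, " world")],
        # top-k alternatives per token position (order not guaranteed)
        "output_top_logprobs": [
            [(-0.11, 9906, "Hello"), (-2.3, 13347, "Hi")],
            [(-1.7, 11, ","), (-0.42, 1917, " world")],
        ],
    }
}

log_probs, top_logprobs = DecodeWorkerHandler._extract_logprobs(
    res, num_output_tokens_so_far=1
)
# log_probs == [-0.42]  # only the one new token
# top_logprobs[0][0] == {"rank": 0, "token_id": 1917, "token": " world", "logprob": -0.42}
# top_logprobs[0][1] == {"rank": 1, "token_id": 11, "token": ",", "logprob": -1.7}
```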

async def _process_token_stream(
self,
stream_source: AsyncGenerator[Dict[str, Any], None],
@@ -239,6 +355,16 @@ async def _process_token_stream(

next_total_toks = len(output_ids)
out["token_ids"] = output_ids[num_output_tokens_so_far:]

# Extract logprobs for new tokens
log_probs, top_logprobs = self._extract_logprobs(
res, num_output_tokens_so_far
)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs

num_output_tokens_so_far = next_total_toks
if finish_reason:
input_tokens = res["meta_info"]["prompt_tokens"]
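Downstream consumers therefore see chunks that carry only the incremental tokens plus, when requested, their logprobs. A sketch of one chunk's shape (values illustrative; `log_probs` and `top_logprobs` are the fields added by this PR):

```python
out = {
    "token_ids": [1917],   # only tokens generated since the previous chunk
    "log_probs": [-0.42],  # parallel to token_ids
    "top_logprobs": [      # one ranked list of alternatives per new token
        [
            {"rank": 0, "token_id": 1917, "token": " world", "logprob": -0.42},
            {"rank": 1, "token_id": 11, "token": ",", "logprob": -1.7},
        ]
    ],
}
```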
29 changes: 29 additions & 0 deletions tests/serve/test_sglang.py
@@ -17,7 +17,9 @@
from tests.utils.payload_builder import (
chat_payload,
chat_payload_default,
chat_payload_with_logprobs,
completion_payload_default,
completion_payload_with_logprobs,
embedding_payload,
embedding_payload_default,
metric_payload_default,
@@ -229,6 +231,33 @@ class SGLangConfig(EngineConfig):
completion_payload_default(),
],
),
"aggregated_logprobs": SGLangConfig(
name="aggregated_logprobs",
directory=sglang_dir,
script_name="agg.sh",
marks=[
pytest.mark.gpu_1,
pytest.mark.pre_merge,
pytest.mark.timeout(240), # 3x measured time + download time
],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
top_logprobs=3,
),
completion_payload_with_logprobs(
repeat_count=2,
expected_response=["AI", "knock", "joke"],
max_tokens=30,
temperature=0.0,
logprobs=5,
),
],
),
}
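
End to end, these payloads exercise the OpenAI-format path in the decode handler. A hedged sketch of the kind of request the chat payload drives; the endpoint, port, and prompt are assumptions, and the real payloads come from tests/utils/payload_builder:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # hypothetical frontend address
    json={
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": "Tell me a joke about AI"}],  # illustrative prompt
        "max_tokens": 30,
        "temperature": 0.0,
        "logprobs": True,    # handler maps this to return_logprob=True
        "top_logprobs": 3,   # handler maps this to top_logprobs_num=3
    },
    timeout=60,
)
# Each returned choice should now include a logprobs block with up to
# 3 ranked alternatives per generated token.
```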

