Commit 430f554

Add logprobs support to TRTLLM backend (linted)

Signed-off-by: Elijah Soba <[email protected]>
Parent: 4ca1679

4 files changed: +241 −17 lines
components/src/dynamo/trtllm/request_handlers/handler_base.py (+99, −0)

@@ -106,6 +106,76 @@ def check_error(self, result: dict):
             result["finish_reason"] == "stop" or result["finish_reason"] == "error"
         )
 
+    @staticmethod
+    def _extract_logprobs(
+        output, num_output_tokens_so_far: int
+    ) -> tuple[list[float] | None, list[list[dict]] | None]:
+        """
+        Extract logprobs from the TRTLLM output for new tokens.
+
+        Args:
+            output: TRTLLM CompletionOutput object
+            num_output_tokens_so_far: Number of tokens already processed
+
+        Returns:
+            Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
+            - log_probs: list of log probabilities for each new token
+            - top_logprobs: list of top-logprob dicts for each new token
+        """
+        if output.logprobs is None:
+            return None, None
+
+        # Get logprobs for the new tokens only
+        new_logprobs = output.logprobs[num_output_tokens_so_far:]
+        if not new_logprobs:
+            return None, None
+
+        # Per the TRTLLM CompletionOutput API, logprobs is (TokenLogprobs | List[float], optional).
+        # TokenLogprobs is expected when logprobs is set; handle the edge case where a
+        # list[float] is returned instead.
+        if isinstance(new_logprobs[0], float):
+            return [float(lp) for lp in new_logprobs], None
+
+        log_probs = []
+        top_logprobs = []
+
+        for token_idx, token_logprobs_dict in enumerate(new_logprobs):
+            if token_logprobs_dict is None:
+                continue
+
+            # Get the actual token_id that was generated at this position
+            actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx]
+
+            # Extract the log probability of the selected token
+            if actual_token_id in token_logprobs_dict:
+                selected_logprob = token_logprobs_dict[actual_token_id]
+                log_probs.append(float(selected_logprob.logprob))
+            else:
+                # Fallback: use the first logprob if the selected token is not found
+                first_logprob = next(iter(token_logprobs_dict.values()), None)
+                if first_logprob:
+                    log_probs.append(float(first_logprob.logprob))
+
+            # Build the top_logprobs list for this token position
+            # NOTE: the TRTLLM Logprob API doesn't have decoded_token, so "token" defaults to None
+            token_top_logprobs = []
+            for tok_id, logprob_info in token_logprobs_dict.items():
+                token_top_logprobs.append(
+                    {
+                        "rank": logprob_info.rank
+                        if hasattr(logprob_info, "rank")
+                        else 0,
+                        "token_id": tok_id,
+                        "token": (
+                            logprob_info.decoded_token
+                            if hasattr(logprob_info, "decoded_token")
+                            else None
+                        ),
+                        "logprob": float(logprob_info.logprob),
+                    }
+                )
+            top_logprobs.append(token_top_logprobs)
+
+        return log_probs if log_probs else None, top_logprobs if top_logprobs else None
+
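To make the return shape concrete, here is a minimal sketch of the mapping _extract_logprobs performs, using SimpleNamespace stand-ins for TRTLLM's CompletionOutput and Logprob objects; the HandlerBase class name is hypothetical, standing in for whatever class hosts this method.

from types import SimpleNamespace

# Stand-in for a TRTLLM CompletionOutput with two generated tokens, where
# logprobs is a list of {token_id: Logprob-like} dicts (the TokenLogprobs form).
output = SimpleNamespace(
    token_ids=[11, 42],
    logprobs=[
        {
            11: SimpleNamespace(logprob=-0.12, rank=1, decoded_token=None),
            7: SimpleNamespace(logprob=-2.31, rank=2, decoded_token=None),
        },
        {42: SimpleNamespace(logprob=-0.05, rank=1, decoded_token=None)},
    ],
)

# HandlerBase is a hypothetical name for the class defined in this file.
log_probs, top_logprobs = HandlerBase._extract_logprobs(output, 0)
# log_probs    -> [-0.12, -0.05]
# top_logprobs -> one list of {"rank", "token_id", "token", "logprob"} dicts
#                 per generated token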
     async def _handle_cancellation(
         self, generation_result: GenerationResult, context: Context
     ):
@@ -236,6 +306,26 @@ async def generate_locally(
             if hasattr(sampling_params, key):
                 setattr(sampling_params, key, value)
 
+        # Additional sampling params in output options
+        output_options = request.get("output_options", {})
+        if output_options:
+            logprobs_value = output_options.get("logprobs")
+
+            # Handle logprobs
+            if logprobs_value is not None:
+                if hasattr(sampling_params, "logprobs"):
+                    setattr(
+                        sampling_params, "logprobs", max(1, int(logprobs_value))
+                    )  # If top_logprobs = 0, still want to see chosen token logprob
+
+            # Handle prompt_logprobs
+            prompt_logprobs_value = output_options.get("prompt_logprobs")
+            if prompt_logprobs_value:
+                if hasattr(sampling_params, "prompt_logprobs"):
+                    setattr(
+                        sampling_params, "prompt_logprobs", int(prompt_logprobs_value)
+                    )
+
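A small sketch of the clamping above, with a bare stand-in for TRTLLM's SamplingParams (only the two attributes the handler probes; not the real class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:  # stand-in, not the TRTLLM class
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None

request = {"output_options": {"logprobs": 0, "prompt_logprobs": 1}}
sampling_params = SamplingParams()

output_options = request.get("output_options", {})
logprobs_value = output_options.get("logprobs")
if logprobs_value is not None and hasattr(sampling_params, "logprobs"):
    # top_logprobs == 0 is clamped to 1 so the chosen token's logprob still flows back
    sampling_params.logprobs = max(1, int(logprobs_value))
if output_options.get("prompt_logprobs") and hasattr(sampling_params, "prompt_logprobs"):
    sampling_params.prompt_logprobs = int(output_options["prompt_logprobs"])

assert sampling_params.logprobs == 1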
         max_tokens = request["stop_conditions"]["max_tokens"]
         if max_tokens:
             sampling_params.max_tokens = max_tokens
@@ -302,6 +392,15 @@ async def generate_locally(
 
                 out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
 
+                # Extract logprobs from the output
+                log_probs, top_logprobs = self._extract_logprobs(
+                    output, num_output_tokens_so_far
+                )
+                if log_probs:
+                    out["log_probs"] = log_probs
+                if top_logprobs:
+                    out["top_logprobs"] = top_logprobs
+
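With both pieces in place, each streamed chunk carries the logprob fields alongside token_ids. Illustrative shape:

out = {
    "token_ids": [42],
    "log_probs": [-0.05],
    "top_logprobs": [
        [{"rank": 1, "token_id": 42, "token": None, "logprob": -0.05}]
    ],
    # "finish_reason" and "stop_reason" are added on the final chunk
}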
                 if output.finish_reason:
                     out["finish_reason"] = output.finish_reason
                 if output.stop_reason:

tests/serve/test_trtllm.py (+32, −0)

@@ -14,7 +14,10 @@
 )
 from tests.utils.engine_process import EngineConfig
 from tests.utils.payload_builder import (
+    TEXT_PROMPT,
+    chat_payload,
     chat_payload_default,
+    completion_payload,
     completion_payload_default,
     metric_payload_default,
     multimodal_payload_default,
@@ -91,6 +94,34 @@ class TRTLLMConfig(EngineConfig):
             metric_payload_default(port=8082, min_num_requests=6, backend="trtllm"),
         ],
     ),
+    "aggregated_logprobs": TRTLLMConfig(
+        name="aggregated_logprobs",
+        directory=trtllm_dir,
+        script_name="agg.sh",
+        marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm],
+        model="Qwen/Qwen3-0.6B",
+        models_port=8000,
+        request_payloads=[
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
+        ],
+    ),
+    "disaggregated_logprobs": TRTLLMConfig(
+        name="disaggregated_logprobs",
+        directory=trtllm_dir,
+        script_name="disagg.sh",
+        marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm],
+        model="Qwen/Qwen3-0.6B",
+        models_port=8000,
+        request_payloads=[
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None),
+            chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0),
+        ],
+    ),
     "aggregated_router": TRTLLMConfig(
         name="aggregated_router",
         directory=trtllm_dir,
@@ -159,6 +190,7 @@ class TRTLLMConfig(EngineConfig):
         },
         request_payloads=[
             completion_payload_default(),
+            completion_payload(prompt=TEXT_PROMPT, logprobs=3),
         ],
     ),
 }
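For reference, the first aggregated_logprobs entry above resolves to roughly this OpenAI-style chat request body (the exact keys come from chat_payload in tests/utils/payload_builder.py; the model name is supplied separately by the test harness):

body = {
    "messages": [{"role": "user", "content": TEXT_PROMPT}],
    "max_tokens": 300,
    "stream": False,
    "logprobs": True,
    "top_logprobs": 5,
}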

tests/utils/payload_builder.py (+44, −17)

@@ -6,7 +6,9 @@
 from tests.utils.client import send_request
 from tests.utils.payloads import (
     ChatPayload,
+    ChatPayloadWithLogprobs,
     CompletionPayload,
+    CompletionPayloadWithLogprobs,
     EmbeddingPayload,
     MetricsPayload,
 )
@@ -134,6 +136,8 @@ def chat_payload(
     max_tokens: int = 300,
     temperature: Optional[float] = None,
     stream: bool = False,
+    logprobs: bool = False,
+    top_logprobs: Optional[int] = None,
     extra_body: Optional[Dict[str, Any]] = None,
 ) -> ChatPayload:
     body: Dict[str, Any] = {
@@ -145,19 +149,31 @@ def chat_payload(
         ],
         "max_tokens": max_tokens,
         "stream": stream,
+        "logprobs": logprobs,
     }
     if temperature is not None:
         body["temperature"] = temperature
 
+    if top_logprobs is not None:
+        body["top_logprobs"] = top_logprobs
+
     if extra_body:
         body.update(extra_body)
 
-    return ChatPayload(
-        body=body,
-        repeat_count=repeat_count,
-        expected_log=expected_log or [],
-        expected_response=expected_response or [],
-    )
+    if logprobs:
+        return ChatPayloadWithLogprobs(
+            body=body,
+            repeat_count=repeat_count,
+            expected_log=expected_log or [],
+            expected_response=expected_response or [],
+        )
+    else:
+        return ChatPayload(
+            body=body,
+            repeat_count=repeat_count,
+            expected_log=expected_log or [],
+            expected_response=expected_response or [],
+        )
161177

162178

163179
def completion_payload(
@@ -168,18 +184,29 @@ def completion_payload(
168184
max_tokens: int = 150,
169185
temperature: float = 0.1,
170186
stream: bool = False,
187+
logprobs: Optional[int] = None,
171188
) -> CompletionPayload:
172-
return CompletionPayload(
173-
body={
174-
"prompt": prompt,
175-
"max_tokens": max_tokens,
176-
"temperature": temperature,
177-
"stream": stream,
178-
},
179-
repeat_count=repeat_count,
180-
expected_log=expected_log or [],
181-
expected_response=expected_response or [],
182-
)
189+
body: Dict[str, Any] = {
190+
"prompt": prompt,
191+
"max_tokens": max_tokens,
192+
"temperature": temperature,
193+
"stream": stream,
194+
}
195+
if logprobs is not None:
196+
body["logprobs"] = logprobs
197+
return CompletionPayloadWithLogprobs(
198+
body=body,
199+
repeat_count=repeat_count,
200+
expected_log=expected_log or [],
201+
expected_response=expected_response or [],
202+
)
203+
else:
204+
return CompletionPayload(
205+
body=body,
206+
repeat_count=repeat_count,
207+
expected_log=expected_log or [],
208+
expected_response=expected_response or [],
209+
)
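The completions builder mirrors the chat builder, except logprobs takes the legacy integer form (number of alternatives per token) rather than a boolean:

p = completion_payload(prompt="Hello", logprobs=3)
assert isinstance(p, CompletionPayloadWithLogprobs)
assert p.body["logprobs"] == 3

q = completion_payload(prompt="Hello")
assert "logprobs" not in q.body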
 
 
 def embedding_payload_default(

tests/utils/payloads.py (+66, −0)

@@ -155,6 +155,39 @@ def response_handler(self, response: Any) -> str:
         return ChatPayload.extract_content(response)
 
 
+@dataclass
+class ChatPayloadWithLogprobs(ChatPayload):
+    """Chat payload that validates logprobs in response."""
+
+    def validate(self, response: Any, content: str) -> None:
+        """Validate response contains logprobs fields."""
+        super().validate(response, content)
+
+        result = response.json()
+        choice = result["choices"][0]
+
+        # Validate logprobs field exists
+        assert "logprobs" in choice, "Missing 'logprobs' in choice"
+
+        logprobs_data = choice["logprobs"]
+        if logprobs_data is not None:
+            assert "content" in logprobs_data, "Missing 'content' in logprobs"
+            content_logprobs = logprobs_data["content"]
+
+            if content_logprobs:
+                # Validate structure of logprobs
+                for item in content_logprobs:
+                    assert "token" in item, "Missing 'token' in logprobs content"
+                    assert "logprob" in item, "Missing 'logprob' in logprobs content"
+                    assert (
+                        "top_logprobs" in item
+                    ), "Missing 'top_logprobs' in logprobs content"
+
+                logger.info(
+                    f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
+                )
+
+
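The validator above targets the OpenAI-style chat logprobs shape; a choice that passes looks roughly like this (values illustrative):

choice = {
    "logprobs": {
        "content": [
            {
                "token": "Hello",
                "logprob": -0.12,
                "top_logprobs": [
                    {"token": "Hello", "logprob": -0.12},
                    {"token": "Hi", "logprob": -2.31},
                ],
            }
        ]
    }
}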
 @dataclass
 class ToolCallingChatPayload(ChatPayload):
     """ChatPayload that validates tool calls in the response."""
@@ -220,6 +253,39 @@ def response_handler(self, response: Any) -> str:
         return CompletionPayload.extract_text(response)
 
 
+@dataclass
+class CompletionPayloadWithLogprobs(CompletionPayload):
+    """Completion payload that validates logprobs in response."""
+
+    def validate(self, response: Any, content: str) -> None:
+        """Validate response contains logprobs fields."""
+        super().validate(response, content)
+
+        result = response.json()
+        choice = result["choices"][0]
+
+        # Validate logprobs field exists
+        assert "logprobs" in choice, "Missing 'logprobs' in choice"
+
+        logprobs_data = choice["logprobs"]
+        if logprobs_data is not None:
+            assert (
+                "token_logprobs" in logprobs_data
+            ), "Missing 'token_logprobs' in logprobs"
+            assert "tokens" in logprobs_data, "Missing 'tokens' in logprobs"
+
+            token_logprobs = logprobs_data["token_logprobs"]
+            tokens = logprobs_data["tokens"]
+
+            if token_logprobs:
+                assert len(token_logprobs) == len(
+                    tokens
+                ), "Mismatch between token_logprobs and tokens length"
+                logger.info(
+                    f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
+                )
+
+
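The completions validator instead expects the legacy parallel-array shape; a passing choice looks roughly like this (values illustrative):

choice = {
    "logprobs": {
        "tokens": ["Hello", "!"],
        "token_logprobs": [-0.12, -0.40],
    }
}
assert len(choice["logprobs"]["tokens"]) == len(choice["logprobs"]["token_logprobs"])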
 @dataclass
 class EmbeddingPayload(BasePayload):
     """Payload for embeddings endpoint."""
