
Commit 17f26af

[trtllm] Upgrade TRT-LLM to version 0.21.0rc1 for djl-serving 0.33.0 (#2848)

Authored by ksuma2109
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Suma Kasa <[email protected]>
1 parent 1437d7a commit 17f26af

File tree

18 files changed: +1162 -252 lines changed


engines/python/setup/djl_python/async_utils.py

Lines changed: 23 additions & 1 deletion
@@ -12,11 +12,33 @@
 # the specific language governing permissions and limitations under the License.
 import json
 import logging
-from typing import AsyncGenerator, Callable, Optional, Union
+from typing import AsyncGenerator, Callable, Optional, Union, Any
 
 from djl_python.outputs import Output
 
 
+class ProcessedRequest:
+
+    def __init__(
+        self,
+        request: Any,
+        inference_invoker: Callable,
+        non_stream_output_formatter: Callable,
+        stream_output_formatter: Callable,
+        accumulate_chunks: bool,
+        include_prompt: bool,
+    ):
+        self.request = request
+        self.inference_invoker = inference_invoker
+        # We need access to both the stream and non-stream output formatters here
+        # because even with streaming requests, there may be some errors before inference that
+        # result in a return of ErrorResponse object instead of AsyncGenerator
+        self.non_stream_output_formatter = non_stream_output_formatter
+        self.stream_output_formatter = stream_output_formatter
+        self.accumulate_chunks = accumulate_chunks
+        self.include_prompt = include_prompt
+
+
 def create_non_stream_output(data: Union[str, dict],
                              error: Optional[str] = None,
                              code: Optional[int] = None) -> Output:
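
ProcessedRequest is a plain container that keeps a parsed request together with the callable that runs inference and the formatters for both response modes, so the handler can still emit a well-formed non-stream error even when a streaming request fails before inference starts. A minimal, self-contained sketch of how a handler might build one; the invoker and the two formatters below are placeholders for illustration, not part of this commit:

from djl_python.async_utils import ProcessedRequest, create_non_stream_output


async def echo_invoker(request):
    # placeholder for the real engine call (e.g. an async generate(...) coroutine)
    return {"generated_text": request["inputs"]}


def non_stream_formatter(response, **_):
    # wrap the final response dict in a DJL Output
    return create_non_stream_output(response)


def stream_formatter(chunk, **_):
    # pass streamed chunks through unchanged; False means "not the last chunk"
    return chunk, False


processed = ProcessedRequest(
    request={"inputs": "Hello"},
    inference_invoker=echo_invoker,
    non_stream_output_formatter=non_stream_formatter,
    stream_output_formatter=stream_formatter,
    accumulate_chunks=False,
    include_prompt=False,
)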
Lines changed: 209 additions & 0 deletions (new file)

@@ -0,0 +1,209 @@
+#!/usr/bin/env python
+#
+# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+import json
+from typing import Union, Tuple, List
+from tensorrt_llm.serve.openai_protocol import (
+    ErrorResponse,
+    ChatCompletionResponse,
+    CompletionResponse,
+    CompletionRequest,
+    CompletionLogProbs,
+)
+from tensorrt_llm.llmapi.tokenizer import TokenizerBase
+from djl_python.async_utils import create_non_stream_output
+from djl_python.outputs import Output
+
+
+def convert_lmi_schema_to_completion_request(
+        payload: dict, ) -> Tuple[CompletionRequest, bool, bool]:
+    parameters = payload.get("parameters", {})
+
+    completion_dict = {
+        "prompt": payload.pop("inputs"),
+        "model": payload.pop("model"),
+        "max_tokens": parameters.pop("max_new_tokens", 30),
+        "echo": parameters.pop("return_full_text", False),
+        "truncate_prompt_tokens": parameters.pop("truncate", None),
+        "n": parameters.pop("top_n_tokens", 1),
+        "ignore_eos": parameters.pop("ignore_eos_token", False),
+        "stream": payload.pop("stream", False),
+    }
+    # TRTLLM does not support logprobs in completions API. If provided, rely on TRTLLM validation error
+    include_details_in_response = False
+    include_prompt = False
+    if completion_dict["stream"]:
+        completion_dict["stream_options"] = {
+            "include_usage": True,
+            "continuous_usage_stats": True
+        }
+        include_prompt = completion_dict.pop("echo", False)
+    if parameters.pop("details", False):
+        include_details_in_response = True
+    if parameters.pop("decoder_input_details", False):
+        completion_dict["return_context_logits"] = 1
+    do_sample = parameters.pop("do_sample", None)
+    # when do_sample is None, just passthrough sampling params as sampling is dictated by the value of other params
+    # when do_sample is False, set sampling params such that we disable sampling
+    if do_sample is not None and not do_sample:
+        parameters["temperature"] = 0.0
+
+    completion_dict.update(parameters)
+
+    return CompletionRequest(
+        **completion_dict), include_details_in_response, include_prompt
+
+
+def convert_completion_response_to_lmi_schema(
+        response: CompletionResponse,
+        request: CompletionRequest = None,
+        include_details: bool = False,
+        tokenizer: TokenizerBase = None) -> Output:
+    primary_choice = response.choices[0]
+    lmi_response = {"generated_text": primary_choice.text}
+    if not include_details:
+        return create_non_stream_output(lmi_response)
+    details = {
+        "finish_reason": primary_choice.stop_reason,
+        "generated_tokens": response.usage.completion_tokens,
+        "seed": request.seed,
+    }
+    lmi_response["details"] = details
+    output = create_non_stream_output(lmi_response)
+    return output
+
+
+def convert_completion_chunk_response_to_lmi_schema(
+    chunk: str,
+    include_details: bool = False,
+    history: List[str] = None,
+    request: CompletionRequest = None,
+    include_prompt: bool = False,
+    tokenizer: TokenizerBase = None,
+    **_,
+) -> Tuple[str, bool, List[str]]:
+    # TRTLLM returns chunks in string format, and the conversion process to TGI
+    # currently converts the string to an object, and then the object back to a string.
+    # It's much easier to work with the object instead of manipulating the string, but inefficient
+    trimmed_chunk = chunk[6:].strip()
+    if trimmed_chunk == '[DONE]':
+        data = ""
+        return data, True, history
+
+    trt_completion_chunk = json.loads(trimmed_chunk)
+    if "error" in trt_completion_chunk:
+        return json.dumps(trt_completion_chunk,
+                          ensure_ascii=False), True, history
+
+    if len(trt_completion_chunk["choices"]) == 0:
+        # penultimate chunk
+        return "", False, history
+    choice = trt_completion_chunk["choices"][0]
+    index = choice["index"]
+    token_text = choice["text"]
+    history.append(token_text)
+    finish_reason = choice["finish_reason"]
+    stop_reason = choice["stop_reason"]
+    usage = trt_completion_chunk["usage"]
+
+    # TODO: TokenId and LogProb here
+    token = {
+        "id": None,
+        "text": token_text,
+        "logprob": None,
+    }
+    tgi_chunk = {
+        "index": index,
+        "token": token,
+        "generated_text": None,
+        "details": None,
+    }
+    generation_finished = finish_reason is not None or stop_reason is not None
+    if generation_finished:
+        generated_text = ''.join(history)
+        if include_prompt:
+            generated_text = request.prompt + generated_text
+        tgi_chunk["generated_text"] = generated_text
+        if include_details:
+            details = {
+                "finish_reason": finish_reason or stop_reason,
+                "seed": request.seed,
+                "generated_tokens": usage["completion_tokens"] + 1,
+                "input_length": usage["prompt_tokens"],
+            }
+            tgi_chunk["details"] = details
+    json_str = json.dumps(tgi_chunk, ensure_ascii=False)
+    return json_str, False, history
+
+
+def lmi_with_details_non_stream_output_formatter(
+    response: CompletionResponse,
+    request: CompletionRequest = None,
+    tokenizer: TokenizerBase = None,
+) -> Output:
+    return convert_completion_response_to_lmi_schema(response,
+                                                     include_details=True,
+                                                     request=request,
+                                                     tokenizer=tokenizer)
+
+
+def lmi_non_stream_output_formatter(
+    response: CompletionResponse,
+    request: CompletionRequest = None,
+    tokenizer: TokenizerBase = None,
+) -> Output:
+    return convert_completion_response_to_lmi_schema(response,
+                                                     include_details=False,
+                                                     request=request,
+                                                     tokenizer=tokenizer)
+
+
+def lmi_with_details_stream_output_formatter(
+    chunk: str,
+    **kwargs,
+) -> Tuple[str, bool, List[str]]:
+    return convert_completion_chunk_response_to_lmi_schema(
+        chunk, include_details=True, **kwargs)
+
+
+def lmi_stream_output_formatter(
+    chunk: str,
+    **kwargs,
+) -> Tuple[str, bool, List[str]]:
+    return convert_completion_chunk_response_to_lmi_schema(chunk, **kwargs)
+
+
+def trtllm_non_stream_output_formatter(
+    response: Union[ErrorResponse, ChatCompletionResponse, CompletionResponse],
+    **_,
+) -> Output:
+    if isinstance(response, ErrorResponse):
+        return create_non_stream_output("",
+                                        error=response.message,
+                                        code=response.code)
+    response_data = response.model_dump_json()
+    return create_non_stream_output(response_data)
+
+
+def trtllm_stream_output_formatter(
+    chunk: str,
+    **_,
+) -> Tuple[str, bool]:
+    # trtllm returns responses in sse format, 'data: {...}'
+    trimmed_chunk = chunk[6:].strip()
+    if trimmed_chunk == '[DONE]':
+        data = ""
+        last = True
+    else:
+        data = trimmed_chunk
+        last = False
+    return data, last
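
The request-side helper above maps the LMI/TGI schema (inputs, parameters.max_new_tokens, return_full_text, details, ...) onto a TRT-LLM CompletionRequest and reports whether details and the prompt should be echoed back. A small sketch of the mapping, assuming tensorrt_llm is installed; djl_python.trtllm_utils below is a placeholder import path, since the new file's path is not shown in this excerpt:

# Placeholder module path for the new helpers; the actual file path is not shown here.
from djl_python.trtllm_utils import convert_lmi_schema_to_completion_request

payload = {
    "inputs": "What is Deep Learning?",
    "model": "my-model",          # placeholder model id; the handler normally fills this in
    "stream": False,
    "parameters": {
        "max_new_tokens": 64,     # becomes CompletionRequest.max_tokens
        "do_sample": False,       # disables sampling by forcing temperature=0.0
        "details": True,          # request TGI-style details in the response
    },
}

request, include_details, include_prompt = convert_lmi_schema_to_completion_request(payload)
# request is a tensorrt_llm CompletionRequest with max_tokens=64 and temperature=0.0;
# include_details is True, and include_prompt stays False because the prompt is only
# echoed for streaming requests that set return_full_text.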
