Commit 2e055fa

feat: Switch to TRT-LLM LLM (High Level) API over trtllm-build CLI workflow (#87)
Co-authored-by: KrishnanPrash <[email protected]>
1 parent: 6603dd7

14 files changed (+231, -1533 lines)

README.md

Lines changed: 49 additions & 34 deletions

@@ -22,8 +22,8 @@ and running the CLI from within the latest corresponding `tritonserver`
 container image, which should have all necessary system dependencies installed.

 For vLLM and TRT-LLM, you can use their respective images:
-- `nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3`
-- `nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3`
+- `nvcr.io/nvidia/tritonserver:24.09-vllm-python-py3`
+- `nvcr.io/nvidia/tritonserver:24.09-trtllm-python-py3`

 If you decide to run the CLI on the host or in a custom image, please
 see this list of [additional dependencies](#additional-dependencies-for-custom-environments)

@@ -38,13 +38,14 @@ matrix below:

 | Triton CLI Version | TRT-LLM Version | Triton Container Tag |
 |:------------------:|:---------------:|:--------------------:|
+| 0.1.0              | v0.13.0         | 24.09                |
 | 0.0.11             | v0.12.0         | 24.08                |
 | 0.0.10             | v0.11.0         | 24.07                |
-| 0.0.9              | v0.10.0         | 24.06                |
-| 0.0.8              | v0.9.0          | 24.05                |
-| 0.0.7              | v0.9.0          | 24.04                |
-| 0.0.6              | v0.8.0          | 24.02, 24.03         |
-| 0.0.5              | v0.7.1          | 24.01                |
+| 0.0.9              | v0.10.0         | 24.06                |
+| 0.0.8              | v0.9.0          | 24.05                |
+| 0.0.7              | v0.9.0          | 24.04                |
+| 0.0.6              | v0.8.0          | 24.02, 24.03         |
+| 0.0.5              | v0.7.1          | 24.01                |

 ### Install from GitHub

@@ -58,7 +59,7 @@ It is also possible to install from a specific branch name, a commit hash
 or a tag name. For example to install `triton_cli` with a specific tag:

 ```bash
-GIT_REF="0.0.11"
+GIT_REF="0.1.0"
 pip install git+https://github.com/triton-inference-server/triton_cli.git@${GIT_REF}
 ```

@@ -93,7 +94,7 @@ triton -h
 triton import -m gpt2

 # Start server pointing at the default model repository
-triton start --image nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3
+triton start --image nvcr.io/nvidia/tritonserver:24.09-vllm-python-py3

 # Infer with CLI
 triton infer -m gpt2 --prompt "machine learning is"

@@ -119,26 +120,50 @@ minutes.
 > in Huggingface through either `huggingface-cli login` or setting the `HF_TOKEN`
 > environment variable.

+### Model Sources

-### Serving a vLLM Model
+<!-- TODO: Add more docs on commands, such as a doc on `import` behavior/args -->

-vLLM models will be downloaded at runtime when starting the server if not found
-locally in the HuggingFace cache. No offline engine building step is required,
-but you can pre-download the model in advance to avoid downloading at server
-startup time.
+The `triton import` command helps automate the process of creating a model repository
+to serve with Triton Inference Server. When preparing models, a `--source` is required
+to point at the location containing a model/weights. This argument is overloaded to support
+a few types of locations:
+- HuggingFace (`--source hf:<HUGGINGFACE_ID>`)
+- Local Filesystem (`--source local:</path/to/model>`)
+
+#### Model Source Aliases

-The following models have currently been tested for vLLM through the CLI:
+<!-- TODO: Put known model sources into a JSON file or something separate from the code -->
+
+For convenience, the Triton CLI supports short aliases for a handful
+of models which will automatically set the correct `--source` for you.
+A full list of aliases can be found from `KNOWN_MODEL_SOURCES` within `parser.py`,
+but some examples can be found below:
 - `gpt2`
 - `opt125m`
 - `mistral-7b`
-- `falcon-7b`
-- `llama-2-7b`
 - `llama-2-7b-chat`
-- `llama-3-8b`
 - `llama-3-8b-instruct`
-- `llama-3.1-8b`
 - `llama-3.1-8b-instruct`

+For example, this command will go get Llama 3.1 8B Instruct from HuggingFace:
+```bash
+triton import -m llama-3.1-8b-instruct
+
+# Equivalent command without alias:
+# triton import --model llama-3.1-8b-instruct --source "hf:meta-llama/Llama-3.1-8B-Instruct"
+```
+
+For full control and flexibility, you can always manually specify the `--source`.
+
+### Serving a vLLM Model
+
+vLLM models will be downloaded at runtime when starting the server if not found
+locally in the HuggingFace cache. No offline engine building step is required,
+but you can pre-download the model in advance to avoid downloading at server
+startup time.
+
+The following models are supported by vLLM: https://docs.vllm.ai/en/latest/models/supported_models.html

 #### Example

@@ -149,10 +174,10 @@ docker run -ti \
 --shm-size=1g --ulimit memlock=-1 \
 -v ${HOME}/models:/root/models \
 -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
-nvcr.io/nvidia/tritonserver:24.08-vllm-python-py3
+nvcr.io/nvidia/tritonserver:24.09-vllm-python-py3

 # Install the Triton CLI
-pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.11
+pip install git+https://github.com/triton-inference-server/triton_cli.git@0.1.0

 # Authenticate with huggingface for restricted models like Llama-2 and Llama-3
 huggingface-cli login

@@ -189,15 +214,7 @@ triton profile -m llama-3-8b-instruct --backend vllm
 > see [here](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#instance-groups).

 The following models are currently supported for automating TRT-LLM
-engine builds through the CLI:
-- `gpt2`
-- `opt125m`
-- `llama-2-7b`
-- `llama-2-7b-chat`
-- `llama-3-8b`
-- `llama-3-8b-instruct`
-- `llama-3.1-8b`
-- `llama-3.1-8b-instruct`
+engine builds through the CLI: https://nvidia.github.io/TensorRT-LLM/llm-api-examples/index.html#supported-models

 > [!NOTE]
 > 1. Building a TRT-LLM engine for Llama-2-7B, Llama-3-8B, or Llama-3.1-8B

@@ -222,10 +239,10 @@ docker run -ti \
 -v /tmp:/tmp \
 -v ${HOME}/models:/root/models \
 -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
-nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3
+nvcr.io/nvidia/tritonserver:24.09-trtllm-python-py3

 # Install the Triton CLI
-pip install git+https://github.com/triton-inference-server/triton_cli.git@0.0.11
+pip install git+https://github.com/triton-inference-server/triton_cli.git@0.1.0

 # Authenticate with huggingface for restricted models like Llama-2 and Llama-3
 huggingface-cli login

@@ -282,5 +299,3 @@ and may not be as optimized as possible for your system or use case.
 - Triton CLI currently uses the TRT-LLM dependencies installed in its environment
   to build TRT-LLM engines, so you must take care to match the build-time and
   run-time versions of TRT-LLM.
-- Triton CLI currently does not support launching the server as a background
-  process.

pyproject.toml

Lines changed: 6 additions & 4 deletions

@@ -47,19 +47,21 @@ keywords = []
 requires-python = ">=3.10,<4"
 # TODO: Add [gpu] set of dependencies for trtllm once it's available on pypi
 dependencies = [
-    "grpcio>=1.65.5",
+    # Client deps - generally versioned together
+    "grpcio>=1.66.1",
+    # Use explicit client version matching genai-perf version for tagged release
+    "tritonclient[all] == 2.50",
+    "genai-perf @ git+https://github.com/triton-inference-server/[email protected]#subdirectory=genai-perf",
+    # Misc deps
     "directory-tree == 0.0.4", # may remove in future
     "docker == 6.1.3",
-    "genai-perf @ git+https://github.com/triton-inference-server/[email protected]#subdirectory=genai-perf",
     # TODO: rely on tritonclient to pull in protobuf and numpy dependencies?
     "numpy >=1.21,<2",
     "protobuf>=3.7.0",
     "prometheus-client == 0.19.0",
     "psutil >= 5.9.5", # may remove later
     "rich == 13.5.2",
     # TODO: Test on cpu-only machine if [cuda] dependency is an issue,
-    # Use explicit client version matching genai-perf version for tagged release
-    "tritonclient[all] == 2.49",
     "huggingface-hub >= 0.19.4",
     # Testing
     "pytest >= 8.1.1", # may remove later

src/triton_cli/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -24,4 +24,4 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-__version__ = "0.0.11"
+__version__ = "0.1.0"

src/triton_cli/docker/Dockerfile

Lines changed: 2 additions & 2 deletions

@@ -1,9 +1,9 @@
 # TRT-LLM image contains engine building and runtime dependencies
-FROM nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3
+FROM nvcr.io/nvidia/tritonserver:24.09-trtllm-python-py3

 # Setup vLLM Triton backend
 RUN mkdir -p /opt/tritonserver/backends/vllm && \
-    git clone -b r24.08 https://github.com/triton-inference-server/vllm_backend.git /tmp/vllm_backend && \
+    git clone -b r24.09 https://github.com/triton-inference-server/vllm_backend.git /tmp/vllm_backend && \
     cp -r /tmp/vllm_backend/src/* /opt/tritonserver/backends/vllm && \
     rm -r /tmp/vllm_backend
src/triton_cli/parser.py

Lines changed: 7 additions & 2 deletions

@@ -228,8 +228,7 @@ def parse_args_repo(parser):
         "--source",
         type=str,
         required=False,
-        help="Local model path or model identifier. Use prefix 'hf:' to specify a HuggingFace model ID. "
-        "NOTE: HuggingFace model support is currently limited to Transformer models through the vLLM backend.",
+        help="Local model path or model identifier. Use prefix 'hf:' to specify a HuggingFace model ID, or 'local:' prefix to specify a file path to a model.",
     )

     repo_remove = parser.add_parser("remove", help="Remove model from model repository")

@@ -305,7 +304,13 @@ def start_server_with_fallback(args: argparse.Namespace, blocking=True):
         try:
             args.mode = mode
             server = start_server(args, blocking=blocking)
+        # TODO: Clean up re-entrant print error
+        except RuntimeError as e:
+            print(e)
+            break
         except Exception as e:
+            print(e)
+            print(type(e))
             msg = f"Failed to start server in '{mode}' mode. {e}"
             logger.debug(msg)
             errors.append(msg)
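The hunk above changes how the fallback loop reacts to failures: a `RuntimeError` now stops the retry loop immediately, while any other exception is recorded and the next launch mode is attempted. A minimal standalone sketch of that control flow follows; the function and mode names here are illustrative placeholders, not the CLI's actual API.

```python
# Simplified sketch of a try-each-mode fallback loop with an early-exit branch.
def start_with_fallback(modes=("local", "docker")):
    errors = []
    for mode in modes:
        try:
            return start_server(mode)  # hypothetical launcher that may raise
        except RuntimeError as e:
            # Unrecoverable condition: report it and stop trying other modes.
            print(e)
            break
        except Exception as e:
            # Recoverable failure: record it and fall back to the next mode.
            errors.append(f"Failed to start server in '{mode}' mode. {e}")
    raise RuntimeError("; ".join(errors) or "Server start aborted")


def start_server(mode):
    # Stand-in for the real launcher; always fails to demonstrate fallback.
    raise Exception(f"mode '{mode}' unavailable in this sketch")
```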

src/triton_cli/repository.py

Lines changed: 34 additions & 62 deletions

@@ -29,8 +29,8 @@
 import shutil
 import logging
 import subprocess
+import multiprocessing
 from pathlib import Path
-from rich.console import Console

 from directory_tree import display_tree

@@ -41,7 +41,6 @@
     TritonCLIException,
 )
 from triton_cli.trt_llm.engine_config_parser import parse_and_substitute
-from triton_cli.trt_llm.builder import TRTLLMBuilder

 from huggingface_hub import snapshot_download
 from huggingface_hub import utils as hf_utils

@@ -66,6 +65,7 @@

 SOURCE_PREFIX_HUGGINGFACE = "hf:"
 SOURCE_PREFIX_NGC = "ngc:"
+SOURCE_PREFIX_LOCAL = "local:"

 TRT_TEMPLATES_PATH = Path(__file__).parent / "templates" / "trt_llm"

@@ -75,35 +75,6 @@

 HF_TOKEN_PATH = Path.home() / ".cache" / "huggingface" / "token"

-# TODO: Improve this flow and reduce hard-coded model check locations
-SUPPORTED_TRT_LLM_BUILDERS = {
-    "facebook/opt-125m": {
-        "hf_allow_patterns": ["*.bin", "*.json", "*.txt"],
-    },
-    "meta-llama/Llama-2-7b-hf": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Llama-2-7b-chat-hf": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3-8B": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3-8B-Instruct": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3.1-8B": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "gpt2": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-        "hf_ignore_patterns": ["onnx/*"],
-    },
-}
-

 # NOTE: Thin wrapper around NGC CLI is a WAR for now.
 # TODO: Move out to generic files/interface for remote model stores

@@ -206,11 +177,19 @@ def add(
             backend = "tensorrtllm"
         # Local model path
         else:
-            logger.debug("No supported prefix detected, assuming local path")
+            if source.startswith(SOURCE_PREFIX_LOCAL):
+                logger.debug("Local prefix detected, parsing local file path")
+            else:
+                logger.info(
+                    "No supported --source prefix detected, assuming local path"
+                )
+
             source_type = "local"
             model_path = Path(source)
             if not model_path.exists():
-                raise TritonCLIException(f"{model_path} does not exist")
+                raise TritonCLIException(
+                    f"Local file path '{model_path}' provided by --source does not exist"
+                )

         model_dir, version_dir = self.__create_model_repository(name, version, backend)

@@ -349,23 +328,21 @@ def __generate_ngc_model(self, name: str, source: str):
             str(self.repo), name, engines_path, engines_path, "auto", dry_run=False
         )

-    def __generate_trtllm_model(self, name, huggingface_id):
-        builder_info = SUPPORTED_TRT_LLM_BUILDERS.get(huggingface_id)
-        if not builder_info:
-            raise TritonCLIException(
-                f"Building a TRT LLM engine for {huggingface_id} is not currently supported."
-            )
-
+    def __generate_trtllm_model(self, name: str, huggingface_id: str):
         engines_path = ENGINE_DEST_PATH + "/" + name
-        hf_download_path = ENGINE_DEST_PATH + "/" + name + "/hf_download"
-
         engines = [engine for engine in Path(engines_path).glob("*.engine")]
         if engines:
             logger.warning(
                 f"Found existing engine(s) at {engines_path}, skipping build."
             )
         else:
-            self.__build_trtllm_engine(huggingface_id, hf_download_path, engines_path)
+            # Run TRT-LLM build in a separate process to make sure it definitely
+            # cleans up any GPU memory used when done.
+            p = multiprocessing.Process(
+                target=self.__build_trtllm_engine, args=(huggingface_id, engines_path)
+            )
+            p.start()
+            p.join()

         # NOTE: In every case, the TRT LLM template should be filled in with values.
         # If the model exists, the CLI will raise an exception when creating the model repo.
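The child-process pattern in the hunk above is deliberate: the engine build allocates GPU memory inside TensorRT-LLM, and running it in a `multiprocessing.Process` ensures that memory is returned to the OS when the child exits, instead of lingering in the long-lived CLI process. A minimal, self-contained sketch of the same pattern, with a placeholder worker rather than the CLI's private method:

```python
import multiprocessing


def build_engine(model_id: str, output_dir: str) -> None:
    # Placeholder for a GPU-heavy build step; in the real code this is the
    # repository's __build_trtllm_engine, which imports tensorrt_llm lazily.
    print(f"building {model_id} into {output_dir}")


if __name__ == "__main__":
    # Run the build in a child process so all resources it acquires (notably
    # GPU memory) are reclaimed as soon as the process exits.
    p = multiprocessing.Process(target=build_engine, args=("gpt2", "/tmp/engines/gpt2"))
    p.start()
    p.join()
    if p.exitcode != 0:
        raise RuntimeError(f"engine build failed with exit code {p.exitcode}")
```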
@@ -375,30 +352,25 @@ def __generate_trtllm_model(self, name, huggingface_id):
             triton_model_dir=str(self.repo),
             bls_model_name=name,
             engine_dir=engines_path,
-            token_dir=hf_download_path,
+            token_dir=engines_path,
             token_type="auto",
             dry_run=False,
         )

-    def __build_trtllm_engine(self, huggingface_id, hf_download_path, engines_path):
-        builder_info = SUPPORTED_TRT_LLM_BUILDERS.get(huggingface_id)
-        hf_allow_patterns = builder_info["hf_allow_patterns"]
-        hf_ignore_patterns = builder_info.get("hf_ignore_patterns", None)
-        self.__download_hf_model(
-            huggingface_id,
-            hf_download_path,
-            allow_patterns=hf_allow_patterns,
-            ignore_patterns=hf_ignore_patterns,
-        )
+    def __build_trtllm_engine(self, huggingface_id: str, engines_path: Path):
+        from tensorrt_llm import LLM, BuildConfig

-        builder = TRTLLMBuilder(
-            huggingface_id=huggingface_id,
-            hf_download_path=hf_download_path,
-            engine_output_path=engines_path,
-        )
-        console = Console()
-        with console.status(f"Building TRT-LLM engine for {huggingface_id}..."):
-            builder.build()
+        # NOTE: Given config.json, can read from 'build_config' section and from_dict
+        config = BuildConfig()
+        # TODO: Expose more build args to user
+        # TODO: Discuss LLM API BuildConfig defaults
+        # config.max_input_len = 1024
+        # config.max_seq_len = 8192
+        # config.max_batch_size = 256
+
+        engine = LLM(huggingface_id, build_config=config)
+        # TODO: Investigate if LLM is internally saving a copy to a temp dir
+        engine.save(str(engines_path))

     def __create_model_repository(
         self, name: str, version: int = 1, backend: str = None
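This last hunk is the heart of the commit: instead of downloading checkpoints and driving a `trtllm-build`-style builder, the CLI now hands the HuggingFace model ID directly to TensorRT-LLM's high-level `LLM` API and saves the resulting engine. Below is a standalone sketch of that flow based only on the calls visible in the diff (`BuildConfig`, `LLM(...)`, `save(...)`); the model ID, output path, and commented-out limits are illustrative, and running it requires a GPU machine with the `tensorrt_llm` package installed (e.g. the 24.09 TRT-LLM container).

```python
# Minimal sketch of building and saving a TRT-LLM engine via the high-level
# LLM API, mirroring the new __build_trtllm_engine.
from pathlib import Path

from tensorrt_llm import LLM, BuildConfig


def build_engine(huggingface_id: str, engines_path: Path) -> None:
    # Build-time knobs live on BuildConfig; the commit keeps the defaults and
    # leaves tuning them as a TODO, so these overrides stay commented out.
    config = BuildConfig()
    # config.max_input_len = 1024
    # config.max_seq_len = 8192
    # config.max_batch_size = 256

    # Constructing LLM with a HuggingFace model ID pulls the checkpoint
    # (if needed) and builds the TensorRT engine in one step.
    engine = LLM(huggingface_id, build_config=config)

    # Persist the built engine so Triton's TRT-LLM backend can load it later.
    engine.save(str(engines_path))


if __name__ == "__main__":
    build_engine("gpt2", Path("/tmp/engines/gpt2"))
```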
