 import shutil
 import logging
 import subprocess
+import multiprocessing
 from pathlib import Path
-from rich.console import Console

 from directory_tree import display_tree

     TritonCLIException,
 )
 from triton_cli.trt_llm.engine_config_parser import parse_and_substitute
-from triton_cli.trt_llm.builder import TRTLLMBuilder

 from huggingface_hub import snapshot_download
 from huggingface_hub import utils as hf_utils

 SOURCE_PREFIX_HUGGINGFACE = "hf:"
 SOURCE_PREFIX_NGC = "ngc:"
+SOURCE_PREFIX_LOCAL = "local:"

 TRT_TEMPLATES_PATH = Path(__file__).parent / "templates" / "trt_llm"

 HF_TOKEN_PATH = Path.home() / ".cache" / "huggingface" / "token"

-# TODO: Improve this flow and reduce hard-coded model check locations
-SUPPORTED_TRT_LLM_BUILDERS = {
-    "facebook/opt-125m": {
-        "hf_allow_patterns": ["*.bin", "*.json", "*.txt"],
-    },
-    "meta-llama/Llama-2-7b-hf": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Llama-2-7b-chat-hf": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3-8B": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3-8B-Instruct": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3.1-8B": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-    },
-    "gpt2": {
-        "hf_allow_patterns": ["*.safetensors", "*.json"],
-        "hf_ignore_patterns": ["onnx/*"],
-    },
-}
-

 # NOTE: Thin wrapper around NGC CLI is a WAR for now.
 # TODO: Move out to generic files/interface for remote model stores
@@ -206,11 +177,19 @@ def add(
             backend = "tensorrtllm"
         # Local model path
         else:
-            logger.debug("No supported prefix detected, assuming local path")
+            if source.startswith(SOURCE_PREFIX_LOCAL):
+                logger.debug("Local prefix detected, parsing local file path")
+            else:
+                logger.info(
+                    "No supported --source prefix detected, assuming local path"
+                )
+
             source_type = "local"
             model_path = Path(source)
             if not model_path.exists():
-                raise TritonCLIException(f"{model_path} does not exist")
+                raise TritonCLIException(
+                    f"Local file path '{model_path}' provided by --source does not exist"
+                )

         model_dir, version_dir = self.__create_model_repository(name, version, backend)
@@ -349,23 +328,21 @@ def __generate_ngc_model(self, name: str, source: str):
             str(self.repo), name, engines_path, engines_path, "auto", dry_run=False
         )

-    def __generate_trtllm_model(self, name, huggingface_id):
-        builder_info = SUPPORTED_TRT_LLM_BUILDERS.get(huggingface_id)
-        if not builder_info:
-            raise TritonCLIException(
-                f"Building a TRT LLM engine for {huggingface_id} is not currently supported."
-            )
-
+    def __generate_trtllm_model(self, name: str, huggingface_id: str):
         engines_path = ENGINE_DEST_PATH + "/" + name
-        hf_download_path = ENGINE_DEST_PATH + "/" + name + "/hf_download"
-
         engines = [engine for engine in Path(engines_path).glob("*.engine")]
         if engines:
             logger.warning(
                 f"Found existing engine(s) at {engines_path}, skipping build."
             )
         else:
-            self.__build_trtllm_engine(huggingface_id, hf_download_path, engines_path)
+            # Run TRT-LLM build in a separate process to make sure it definitely
+            # cleans up any GPU memory used when done.
+            p = multiprocessing.Process(
+                target=self.__build_trtllm_engine, args=(huggingface_id, engines_path)
+            )
+            p.start()
+            p.join()

         # NOTE: In every case, the TRT LLM template should be filled in with values.
         # If the model exists, the CLI will raise an exception when creating the model repo.
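The child-process wrapper above exists so that any GPU memory allocated during the engine build is released when the process exits, instead of lingering in the long-lived CLI process. A self-contained sketch of the same pattern follows; the function name, arguments, and exit-code check are placeholders, not triton_cli code.

```python
import multiprocessing


def build_engine(model_id: str, output_dir: str) -> None:
    # Placeholder for GPU-heavy work (e.g. a TRT-LLM engine build); anything the
    # libraries allocate on the GPU is freed when this child process exits.
    ...


if __name__ == "__main__":
    p = multiprocessing.Process(target=build_engine, args=("my-model", "/tmp/engines"))
    p.start()
    p.join()
    # Optional hardening not shown in the diff: surface child failures to the caller.
    if p.exitcode != 0:
        raise RuntimeError(f"engine build failed with exit code {p.exitcode}")
```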
@@ -375,30 +352,25 @@ def __generate_trtllm_model(self, name, huggingface_id):
             triton_model_dir=str(self.repo),
             bls_model_name=name,
             engine_dir=engines_path,
-            token_dir=hf_download_path,
+            token_dir=engines_path,
             token_type="auto",
             dry_run=False,
         )

-    def __build_trtllm_engine(self, huggingface_id, hf_download_path, engines_path):
-        builder_info = SUPPORTED_TRT_LLM_BUILDERS.get(huggingface_id)
-        hf_allow_patterns = builder_info["hf_allow_patterns"]
-        hf_ignore_patterns = builder_info.get("hf_ignore_patterns", None)
-        self.__download_hf_model(
-            huggingface_id,
-            hf_download_path,
-            allow_patterns=hf_allow_patterns,
-            ignore_patterns=hf_ignore_patterns,
-        )
+    def __build_trtllm_engine(self, huggingface_id: str, engines_path: Path):
+        from tensorrt_llm import LLM, BuildConfig

-        builder = TRTLLMBuilder(
-            huggingface_id=huggingface_id,
-            hf_download_path=hf_download_path,
-            engine_output_path=engines_path,
-        )
-        console = Console()
-        with console.status(f"Building TRT-LLM engine for {huggingface_id}..."):
-            builder.build()
+        # NOTE: Given config.json, can read from 'build_config' section and from_dict
+        config = BuildConfig()
+        # TODO: Expose more build args to user
+        # TODO: Discuss LLM API BuildConfig defaults
+        # config.max_input_len = 1024
+        # config.max_seq_len = 8192
+        # config.max_batch_size = 256
+
+        engine = LLM(huggingface_id, build_config=config)
+        # TODO: Investigate if LLM is internally saving a copy to a temp dir
+        engine.save(str(engines_path))

     def __create_model_repository(
         self, name: str, version: int = 1, backend: str = None
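The NOTE in the new __build_trtllm_engine points out that BuildConfig could be seeded from an engine's config.json rather than left at defaults. A hedged sketch of what that might look like; the file layout, the "build_config" key, and the BuildConfig.from_dict call are taken from that comment and may differ across TRT-LLM versions.

```python
import json

from tensorrt_llm import LLM, BuildConfig

# Assumed layout: a TRT-LLM engine config.json containing a "build_config" section.
with open("config.json") as f:
    build_cfg_dict = json.load(f).get("build_config", {})

config = BuildConfig.from_dict(build_cfg_dict)
engine = LLM("meta-llama/Meta-Llama-3-8B", build_config=config)
engine.save("engines/my_model")
```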