diff --git a/src/triton_cli/repository.py b/src/triton_cli/repository.py index 7728d55..e17334f 100644 --- a/src/triton_cli/repository.py +++ b/src/triton_cli/repository.py @@ -185,7 +185,7 @@ def add( ) source_type = "local" - model_path = Path(source) + model_path = Path(source.replace(SOURCE_PREFIX_LOCAL, "")) if not model_path.exists(): raise TritonCLIException( f"Local file path '{model_path}' provided by --source does not exist" @@ -212,8 +212,15 @@ def add( # point to downloaded engines, etc. self.__generate_ngc_model(name, ngc_model_name) else: - logger.debug(f"Copying {model_path} to {version_dir}") - shutil.copy(model_path, version_dir) + if model_path.is_dir(): + logger.info(f"Copying model directory {model_path} to {version_dir}") + # If version_dir already exists, remove it first + if version_dir.exists(): + shutil.rmtree(version_dir) + shutil.copytree(model_path, version_dir) + else: + logger.info(f"Copying model file {model_path} to {version_dir}") + shutil.copy(model_path, version_dir) if verbose: self.list()