From 0ce74732b828becc77cf07d560917789706fd802 Mon Sep 17 00:00:00 2001 From: maximpavliv Date: Tue, 15 Jul 2025 11:44:16 +0200 Subject: [PATCH 1/3] download_huggingface_model: allow passing str as rename mapping --- dlclibrary/dlcmodelzoo/modelzoo_download.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dlclibrary/dlcmodelzoo/modelzoo_download.py b/dlclibrary/dlcmodelzoo/modelzoo_download.py index 9da4c6b..3517630 100644 --- a/dlclibrary/dlcmodelzoo/modelzoo_download.py +++ b/dlclibrary/dlcmodelzoo/modelzoo_download.py @@ -139,7 +139,7 @@ def download_huggingface_model( model_name: str, target_dir: str = ".", remove_hf_folder: bool = True, - rename_mapping: dict | None = None, + rename_mapping: str | dict | None = None, ): """ Downloads a DeepLabCut Model Zoo Project from Hugging Face. @@ -180,6 +180,10 @@ def download_huggingface_model( path_ = os.path.join(target_dir, hf_folder, "snapshots") commit = os.listdir(path_)[0] file_name = os.path.join(path_, commit, targzfn) + + if isinstance(rename_mapping, str): + rename_mapping = {targzfn: rename_mapping} + _handle_downloaded_file(file_name, target_dir, rename_mapping) if remove_hf_folder: From 61853d0aa50fc31c337739016778b3b08a50150d Mon Sep 17 00:00:00 2001 From: maximpavliv Date: Tue, 2 Sep 2025 15:42:59 +0200 Subject: [PATCH 2/3] Improve docstring --- dlclibrary/dlcmodelzoo/modelzoo_download.py | 28 ++++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/dlclibrary/dlcmodelzoo/modelzoo_download.py b/dlclibrary/dlcmodelzoo/modelzoo_download.py index 3517630..6249040 100644 --- a/dlclibrary/dlcmodelzoo/modelzoo_download.py +++ b/dlclibrary/dlcmodelzoo/modelzoo_download.py @@ -145,13 +145,29 @@ def download_huggingface_model( Downloads a DeepLabCut Model Zoo Project from Hugging Face. Args: - model_name (str): Name of the ModelZoo model. + model_name (str): + Name of the ModelZoo model. For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo. - target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored. - remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace - after downloading and decompressing the data into DeepLabCut format. Defaults to True. - rename_mapping (dict, optional): A dictionary to rename the downloaded file. - If None, the original filename is used. Defaults to None. + target_dir (str, optional): + Target directory where the model weights will be stored. + Defaults to the current directory. + remove_hf_folder (bool, optional): + Whether to remove the directory structure created by HuggingFace + after downloading and decompressing the data into DeepLabCut format. + Defaults to True. + rename_mapping (dict | str | None, optional): + - If a dictionary, it should map the original Hugging Face filenames + to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}). + - If a string, it is interpreted as the new name for the downloaded file + - If None, the original filename is used. + Defaults to None. + + Examples: + >>> # Download without renaming, keep original filename + download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False) + + >>> # Download and rename by specifying the new name directly + download_huggingface_model("superanimal_humanbody_rtmpose_x", target_dir="/path/to/,y/checkpoints", rename_mapping="superanimal_humanbody_rtmpose_x.pt") """ net_urls = _load_model_names() if model_name not in net_urls: From 1c78f82b5c77d3620938febf0e4c1ee003d92ec9 Mon Sep 17 00:00:00 2001 From: maximpavliv Date: Tue, 2 Sep 2025 15:44:47 +0200 Subject: [PATCH 3/3] Formatting --- dlclibrary/dlcmodelzoo/modelzoo_download.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dlclibrary/dlcmodelzoo/modelzoo_download.py b/dlclibrary/dlcmodelzoo/modelzoo_download.py index 6249040..12dff74 100644 --- a/dlclibrary/dlcmodelzoo/modelzoo_download.py +++ b/dlclibrary/dlcmodelzoo/modelzoo_download.py @@ -94,7 +94,7 @@ def get_available_datasets() -> list[str]: def get_available_detectors(dataset: str) -> list[str]: - """ Only for PyTorch models. + """Only for PyTorch models. Returns: The detectors available for the dataset. @@ -103,7 +103,7 @@ def get_available_detectors(dataset: str) -> list[str]: def get_available_models(dataset: str) -> list[str]: - """ Only for PyTorch models. + """Only for PyTorch models. Returns: The pose models available for the dataset. @@ -167,7 +167,11 @@ def download_huggingface_model( download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False) >>> # Download and rename by specifying the new name directly - download_huggingface_model("superanimal_humanbody_rtmpose_x", target_dir="/path/to/,y/checkpoints", rename_mapping="superanimal_humanbody_rtmpose_x.pt") + download_huggingface_model( + model_name="superanimal_humanbody_rtmpose_x", + target_dir="/path/to/,y/checkpoints", + rename_mapping="superanimal_humanbody_rtmpose_x.pt" + ) """ net_urls = _load_model_names() if model_name not in net_urls: