diff --git a/dlclibrary/dlcmodelzoo/modelzoo_download.py b/dlclibrary/dlcmodelzoo/modelzoo_download.py index 9da4c6b..12dff74 100644 --- a/dlclibrary/dlcmodelzoo/modelzoo_download.py +++ b/dlclibrary/dlcmodelzoo/modelzoo_download.py @@ -94,7 +94,7 @@ def get_available_datasets() -> list[str]: def get_available_detectors(dataset: str) -> list[str]: - """ Only for PyTorch models. + """Only for PyTorch models. Returns: The detectors available for the dataset. @@ -103,7 +103,7 @@ def get_available_detectors(dataset: str) -> list[str]: def get_available_models(dataset: str) -> list[str]: - """ Only for PyTorch models. + """Only for PyTorch models. Returns: The pose models available for the dataset. @@ -139,19 +139,39 @@ def download_huggingface_model( model_name: str, target_dir: str = ".", remove_hf_folder: bool = True, - rename_mapping: dict | None = None, + rename_mapping: str | dict | None = None, ): """ Downloads a DeepLabCut Model Zoo Project from Hugging Face. Args: - model_name (str): Name of the ModelZoo model. + model_name (str): + Name of the ModelZoo model. For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo. - target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored. - remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace - after downloading and decompressing the data into DeepLabCut format. Defaults to True. - rename_mapping (dict, optional): A dictionary to rename the downloaded file. - If None, the original filename is used. Defaults to None. + target_dir (str, optional): + Target directory where the model weights will be stored. + Defaults to the current directory. + remove_hf_folder (bool, optional): + Whether to remove the directory structure created by HuggingFace + after downloading and decompressing the data into DeepLabCut format. + Defaults to True. + rename_mapping (dict | str | None, optional): + - If a dictionary, it should map the original Hugging Face filenames + to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}). + - If a string, it is interpreted as the new name for the downloaded file + - If None, the original filename is used. + Defaults to None. + + Examples: + >>> # Download without renaming, keep original filename + download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False) + + >>> # Download and rename by specifying the new name directly + download_huggingface_model( + model_name="superanimal_humanbody_rtmpose_x", + target_dir="/path/to/,y/checkpoints", + rename_mapping="superanimal_humanbody_rtmpose_x.pt" + ) """ net_urls = _load_model_names() if model_name not in net_urls: @@ -180,6 +200,10 @@ def download_huggingface_model( path_ = os.path.join(target_dir, hf_folder, "snapshots") commit = os.listdir(path_)[0] file_name = os.path.join(path_, commit, targzfn) + + if isinstance(rename_mapping, str): + rename_mapping = {targzfn: rename_mapping} + _handle_downloaded_file(file_name, target_dir, rename_mapping) if remove_hf_folder: