From a8d544b397d7eedba3952a5c8b9cd93f9e310b3c Mon Sep 17 00:00:00 2001 From: RinZ27 <222222878+RinZ27@users.noreply.github.com> Date: Fri, 27 Mar 2026 21:30:51 +0700 Subject: [PATCH] security(opt): enable weights_only=True by default for secure checkpoint loading Signed-off-by: RinZ27 <222222878+RinZ27@users.noreply.github.com> --- CHANGELOG.rst | 4 ++ examples/deepseek/quantize_to_nvfp4.py | 3 +- examples/llm_eval/modeling.py | 3 +- .../compute_hidden_states_trtllm.py | 3 +- .../sample_hidden_states.py | 3 +- examples/speculative_decoding/eagle_utils.py | 3 +- examples/speculative_decoding/main.py | 3 +- examples/vllm_serve/fakequant_worker.py | 3 +- modelopt/torch/export/distribute.py | 10 +-- modelopt/torch/export/model_config.py | 28 +++++++- modelopt/torch/opt/config.py | 8 +++ modelopt/torch/opt/conversion.py | 15 ++-- .../opt/plugins/mcore_dist_checkpointing.py | 23 ++++-- modelopt/torch/opt/plugins/megatron.py | 15 ++-- modelopt/torch/opt/plugins/peft.py | 10 ++- modelopt/torch/quantization/model_calib.py | 15 ++-- .../quantization/qtensor/base_qtensor.py | 8 ++- .../torch/quantization/qtensor/fp8_tensor.py | 4 ++ .../torch/quantization/qtensor/int4_tensor.py | 4 ++ .../torch/quantization/qtensor/int8_tensor.py | 4 ++ .../quantization/qtensor/mxfp4_tensor.py | 4 ++ .../quantization/qtensor/mxfp8_tensor.py | 4 ++ .../torch/quantization/qtensor/nf4_tensor.py | 4 ++ .../quantization/qtensor/nvfp4_tensor.py | 6 ++ modelopt/torch/utils/__init__.py | 1 + modelopt/torch/utils/serialization.py | 42 +++++++++++ tests/unit/torch/utils/test_serialization.py | 72 +++++++++++++++++++ 27 files changed, 261 insertions(+), 41 deletions(-) create mode 100644 modelopt/torch/utils/serialization.py create mode 100644 tests/unit/torch/utils/test_serialization.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cc172bdcf0..fdfc5eb754 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,6 +3,10 @@ NVIDIA Model Optimizer Changelog 0.44 (2026-05-xx) ^^^^^^^^^^^^^^^^^ +**Backward 
Breaking Changes** + +- Changed the default of ``weights_only`` to ``True`` in ``torch.load`` for secure checkpoint loading. If you need to load a checkpoint that requires unpickling arbitrary objects, explicitly set ``weights_only=False``. + **New Features** - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index af387fce5b..e4f26f7c97 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -1,3 +1,4 @@ +from modelopt.torch.utils.serialization import safe_load # Adapted from https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/fp8_cast_bf16.py # MIT License @@ -98,7 +99,7 @@ def remove_quantization_config_from_original_config(export_dir: str) -> None: def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): state_dict_list = [ - torch.load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt") + safe_load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt") for rank in range(world_size) ] diff --git a/examples/llm_eval/modeling.py b/examples/llm_eval/modeling.py index d06d055603..1496276ea1 100644 --- a/examples/llm_eval/modeling.py +++ b/examples/llm_eval/modeling.py @@ -1,3 +1,4 @@ +from modelopt.torch.utils.serialization import safe_load # Adapted from https://github.com/declare-lab/instruct-eval/blob/720e66f627369266ed1cfd74426666ec37e524bc/modeling.py # MIT License @@ -428,7 +429,7 @@ def load_quant( model.load_state_dict(safe_load(checkpoint), strict=False) else: - model.load_state_dict(torch.load(checkpoint), strict=False) + model.load_state_dict(safe_load(checkpoint), 
strict=False) if eval: quant.make_quant_attn(model) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py index 0bf68e430f..56842993c0 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py @@ -1,3 +1,4 @@ +from modelopt.torch.utils.serialization import safe_load # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # @@ -219,7 +220,7 @@ async def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: if not trtllm_dumped_file.exists(): return False with open(trtllm_dumped_file, "rb") as f: - trtllm_dumped = torch.load(f) + trtllm_dumped = safe_load(f) assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, ( "TRTLLM dumped should be a list with one element" ) diff --git a/examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py b/examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py index 75a88969d3..9387e70244 100644 --- a/examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py +++ b/examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py @@ -1,3 +1,4 @@ +from modelopt.torch.utils.serialization import safe_load # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # @@ -60,7 +61,7 @@ def main(args: argparse.Namespace) -> None: ) for i, file in enumerate(sampled_files): - data = torch.load(file) + data = safe_load(file) expected_keys = [ "input_ids", "hidden_states", diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 1bc7df9810..e3d31f6a0e 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -41,6 +41,7 @@ ShardedDataset, VisionLanguageDataCollator, ) +from modelopt.torch.utils.serialization import safe_load try: import wandb @@ -71,7 +72,7 @@ def __len__(self): def __getitem__(self, i) -> dict[str, torch.Tensor]: try: - offline_data = torch.load(self.dumped_files[i]) + offline_data = safe_load(self.dumped_files[i]) except Exception as e: print( f"[ERROR] Failed to load file at index={i}, " diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3369d399c2..5fa6b39c42 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -1,3 +1,4 @@ +from modelopt.torch.utils.serialization import safe_load # Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li @@ -234,7 +235,7 @@ def train(): raise FileNotFoundError( f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" ) - model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) + model.eagle_module.d2t = safe_load(data_args.draft_vocab_cache) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") else: raise Exception(f"{training_args.mode} is not supported!") diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe669..2cd29f324a 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -1,3 
+1,4 @@ +from modelopt.torch.utils.serialization import safe_load # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # @@ -284,7 +285,7 @@ def calibrate_loop(model: Any = None) -> None: amax_file_path = quant_config["amax_file_path"] if amax_file_path: print(f"Loading amax values from {amax_file_path}") - saved_amax_dict = torch.load(amax_file_path) + saved_amax_dict = safe_load(amax_file_path) # convert amax keys to vLLM format if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict) diff --git a/modelopt/torch/export/distribute.py b/modelopt/torch/export/distribute.py index 4fe7be43ef..8f401be72f 100644 --- a/modelopt/torch/export/distribute.py +++ b/modelopt/torch/export/distribute.py @@ -25,6 +25,7 @@ import torch from modelopt.torch.utils import distributed as dist +from modelopt.torch.utils.serialization import safe_load from .model_config_utils import ( model_config_from_dict, @@ -41,7 +42,7 @@ class NFSWorkspace: communication nor barrier. It is users' responsibility to synchronize all ranks (local and remove processes). - This implementation uses `torch.save` and `torch.load` for serialization. + This implementation uses `torch.save` and `safe_load` for serialization. Args: workspace_path: the path to the NFS directory for postprocess cross rank communication. 
@@ -91,8 +92,7 @@ def read_configs_and_weights_from_rank( raise ValueError("NFSWorkspace is not initialized!") state_path = self._get_state_path(target_rank) if state_path.exists(): - # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input - state = torch.load(state_path, map_location="cpu", weights_only=False) + state = safe_load(state_path, map_location="cpu") return state["config"], state["weight"] else: return None, None @@ -157,7 +157,7 @@ def get_tensors_parallel(tensor: torch.Tensor, ranks: list[int], group=None): tensors.append(tensor) else: shm = SharedMemory(name=f"rank_{rank}", create=False) - shared_tensor = torch.load(BytesIO(shm.buf)) + shared_tensor = safe_load(BytesIO(shm.buf)) tensors.append(shared_tensor) shm_readers.append(shm) try: @@ -276,7 +276,7 @@ def _get_weights_nbytes(weights_dict: dict[str, torch.Tensor]): if len_json != 0: config_dict = json.loads(shm.buf[8 : 8 + len_json].tobytes().decode()) - weights = torch.load(BytesIO(shm.buf[8 + len_json :])) + weights = safe_load(BytesIO(shm.buf[8 + len_json :])) restore_model_config(config_dict, weights) config = model_config_from_dict(config_dict) diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py index 4826d06391..60b6c61928 100755 --- a/modelopt/torch/export/model_config.py +++ b/modelopt/torch/export/model_config.py @@ -631,5 +631,31 @@ def num_kv_heads(self): @property def hidden_act(self): - """Returns the hidden_act of the model.""" + """Returns the hidden_act of the model.""" return self.layers[0].mlp.hidden_act + + + # Register all config classes as safe globals + from modelopt.torch.utils.serialization import add_modelopt_safe_globals + + add_modelopt_safe_globals( + [ + EmbeddingConfig, + LayernormConfig, + LinearConfig, + LinearActConfig, + ConvConfig, + QKVConfig, + RelativeAttentionTableConfig, + AttentionConfig, + MLPConfig, + ExpertConfig, + RgLruConfig, + RecurrentConfig, + MOEConfig, 
+ DecoderLayerConfig, + MedusaHeadConfig, + ModelConfig, + ] + ) + diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 032b9fe6bf..9794477de2 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -20,6 +20,7 @@ from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView from typing import Any, TypeAlias +import torch from pydantic import ( BaseModel, Field, @@ -65,6 +66,13 @@ class ModeloptBaseConfig(BaseModel): model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True) + def __init_subclass__(cls, **kwargs: Any) -> None: + """Register the config class as a safe global for torch serialization.""" + super().__init_subclass__(**kwargs) + + # Register the class as a safe global for torch serialization + torch.serialization.add_safe_globals([cls]) + def model_dump(self, **kwargs): """Dump the config to a dictionary with aliases and no warnings by default.""" kwargs = {"by_alias": True, "warnings": False, **kwargs} diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 6ec7a17298..9f29c1d3e4 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -35,6 +35,7 @@ from modelopt import __version__ from modelopt.torch.utils import ModelLike, init_model_from_model_like, unwrap_model +from modelopt.torch.utils.serialization import safe_load from .config import ConfigDict, ModeloptBaseConfig from .mode import ( @@ -523,12 +524,8 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic Returns: A modelopt state dictionary describing the modifications to the model. """ - # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input - kwargs.setdefault("weights_only", False) kwargs.setdefault("map_location", "cpu") - # TODO: Add some validation to ensure the file is a valid modelopt state file. 
- modelopt_state = torch.load(modelopt_state_path, **kwargs) - return modelopt_state + return safe_load(modelopt_state_path, **kwargs) def restore_from_modelopt_state( @@ -550,7 +547,9 @@ def restore_from_modelopt_state( # Restore the previously saved modelopt state followed by model weights mto.restore_from_modelopt_state(model, modelopt_state_path="modelopt_state.pt") - model.load_state_dict(torch.load("model_weights.pt"), ...) # Load the model weights + model.load_state_dict( + torch.load("model_weights.pt", weights_only=True), ... + ) # Load the model weights If you want to restore the model weights and the modelopt state with saved scales, please use :meth:`mto.restore()`. @@ -628,8 +627,8 @@ def restore(model: ModelLike, f: str | os.PathLike | BinaryIO, **kwargs) -> nn.M # load checkpoint kwargs.setdefault("map_location", "cpu") - kwargs.setdefault("weights_only", False) - objs = torch.load(f, **kwargs) + # Security NOTE: weights_only=True is used here on ModelOpt-generated checkpoints + objs = safe_load(f, **kwargs) # restore model architecture model_restored = restore_from_modelopt_state(model, objs["modelopt_state"]) diff --git a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py index 16ace511fe..4c5509e106 100644 --- a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py +++ b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py @@ -21,6 +21,7 @@ from pathlib import Path from typing import Any +import megatron.core as mcore import torch from megatron.core import dist_checkpointing, mpu from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy @@ -32,6 +33,7 @@ import modelopt.torch.opt as mto import modelopt.torch.utils.distributed as dist from modelopt.torch.utils.network import SUPPORTED_WRAPPERS +from modelopt.torch.utils.serialization import safe_load SUPPORTED_WRAPPERS[Float16Module] = "module" @@ -131,6 +133,14 @@ def save_sharded_modelopt_state( 
os.makedirs(modelopt_checkpoint_name, exist_ok=True) modelopt_state = copy.deepcopy(mto.modelopt_state(model[0])) remove_per_module_state(modelopt_state) + + # Persist metadata for MCore version and sharded layout + modelopt_state["mcore_metadata"] = { + "version": mcore.__version__, + "singleton_local_shards": True, # Consistent with _MegatronMLP.sharded_state_dict() + "sharded_strategy": sharded_strategy, + } + dist_checkpointing.save(modelopt_state, modelopt_checkpoint_name, sharded_strategy) @@ -156,7 +166,8 @@ def _load_extra_state_from_sharded_checkpoint( is set to `True` (was not set before) in megatron-core-0.15.0. This flag affects the sharded state_dict format and must be consistent between saving and loading. """ - sharded_state_dict = model.sharded_state_dict(prefix=prefix) + # sharded_state_dict() should be called with metadata if provided to maintain consistency + sharded_state_dict = model.sharded_state_dict(prefix=prefix, metadata=metadata) extra_sharded_state_dict = {k: v for k, v in sharded_state_dict.items() if "_extra_state" in k} extra_state_dict = dist_checkpointing.load( extra_sharded_state_dict, @@ -203,11 +214,15 @@ def restore_sharded_modelopt_state( return # Loading the common modelopt_state (replicated on all ranks) - # Security NOTE: weights_only=False is used here on NVIDIA-generated file, not on untrusted user input - common_modelopt_state = torch.load( - modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, weights_only=False + common_modelopt_state = safe_load( + modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, map_location="cpu" ) + # Try to retrieve metadata from checkpoint + mcore_metadata = common_modelopt_state.get("mcore_metadata", None) + if metadata is None: + metadata = mcore_metadata + modelopt_load_version = common_modelopt_state["modelopt_version"] print(f"nvidia-modelopt ckpt/inst version: {modelopt_load_version}/{modelopt.__version__}") diff --git a/modelopt/torch/opt/plugins/megatron.py 
b/modelopt/torch/opt/plugins/megatron.py index 761e8d9a4c..5a8d58d504 100644 --- a/modelopt/torch/opt/plugins/megatron.py +++ b/modelopt/torch/opt/plugins/megatron.py @@ -15,8 +15,8 @@ """Support quantization and save/resore for Megatron.""" import contextlib -import pickle # nosec import types +from io import BytesIO from typing import Any import megatron.core.transformer.mlp as megatron_mlp @@ -24,6 +24,8 @@ import torch from megatron.core.parallel_state import get_data_parallel_group +from modelopt.torch.utils.serialization import safe_load + from ..dynamic import DynamicModule @@ -82,8 +84,10 @@ def _modelopt_get_extra_state(self): # Serialize state into byte tensor torch.cuda.synchronize() - state_serialized = bytearray(pickle.dumps(extra_state)) # nosec - state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) + # Use torch.save for serialization to match safe_load + buffer = BytesIO() + torch.save(extra_state, buffer) + state_serialized = torch.frombuffer(buffer.getvalue(), dtype=torch.uint8) return state_serialized @@ -102,10 +106,7 @@ def _modelopt_set_extra_state(self, state: Any): if state.numel() == 0: return # Default format: byte tensor with pickled data - # - # TODO: possible deserialization improvement - # https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serialization.py - extra_state = pickle.loads(state.detach().cpu().numpy().tobytes()) # nosec + extra_state = safe_load(state.detach().cpu().numpy().tobytes(), map_location="cpu") else: raise RuntimeError("Unsupported extra_state format.") diff --git a/modelopt/torch/opt/plugins/peft.py b/modelopt/torch/opt/plugins/peft.py index de1218917f..a475e9f428 100644 --- a/modelopt/torch/opt/plugins/peft.py +++ b/modelopt/torch/opt/plugins/peft.py @@ -22,6 +22,7 @@ from peft import PeftModel from modelopt.torch.utils import get_unwrapped_name, print_rank_0 +from modelopt.torch.utils.serialization import safe_load from ..conversion import ModeloptStateManager, modelopt_state, 
restore_from_modelopt_state from .huggingface import register_for_patching @@ -83,11 +84,14 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs): if os.path.isfile(_get_quantizer_state_save_path(model_id)): from modelopt.torch.quantization.nn import TensorQuantizer - # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input - quantizer_state_dict = torch.load( - _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False + # Security NOTE: weights_only=False is required here because quantizer_state_dict contains + # custom ModelOpt types like QTensorWrapper which are pickled. These files are expected + # to be generated by the PEFT plugin, not from untrusted sources. + quantizer_state_dict = safe_load( + _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False ) for name, module in self.named_modules(): + if isinstance(module, TensorQuantizer): module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)]) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index ed57ea3fc7..4823d04816 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1729,12 +1729,17 @@ def initialize_hessian_state(tensor_mapping): "n_samples": 0, } - def load_hessian_state(path, tensor_mapping): - """Load hessian state from file.""" - print_rank_0(f"Loading hessian state from {path}") - loaded_state = torch.load(path, map_location="cpu") + from modelopt.torch.utils.serialization import safe_load + + ... 
+ + def load_hessian_state(path, tensor_mapping): + """Load hessian state from file.""" + print_rank_0(f"Loading hessian state from {path}") + loaded_state = safe_load(path, map_location="cpu") + + for name, (shape, device) in tensor_mapping.items(): - for name, (shape, device) in tensor_mapping.items(): if name not in loaded_state: raise KeyError(f"Layer '{name}' not found in loaded hessian state") diff --git a/modelopt/torch/quantization/qtensor/base_qtensor.py b/modelopt/torch/quantization/qtensor/base_qtensor.py index d5a9a4269e..38facd3f96 100644 --- a/modelopt/torch/quantization/qtensor/base_qtensor.py +++ b/modelopt/torch/quantization/qtensor/base_qtensor.py @@ -116,7 +116,7 @@ def to(self, *args, **kwargs): changing_device, changing_dtype, *_ = torch._C._nn._parse_to(*args, **kwargs) if changing_device: self.data = self.data.to(device=changing_device) - dtype = changing_dtype if changing_dtype else self.metadata["dtype"] + dtype = changing_dtype or self.metadata["dtype"] return QTensorWrapper( self.metadata["qtensor_class"](self.metadata["shape"], dtype, self.data) ) @@ -227,3 +227,9 @@ def _compress_and_update_module_weight(module): if name != "": with fsdp2_aware_weight_update(module, m): _compress_and_update_module_weight(m) + + +# Register QTensor types as safe globals +from modelopt.torch.utils.serialization import add_modelopt_safe_globals + +add_modelopt_safe_globals([BaseQuantizedTensor, QTensorWrapper]) diff --git a/modelopt/torch/quantization/qtensor/fp8_tensor.py b/modelopt/torch/quantization/qtensor/fp8_tensor.py index a8cce4415b..c796f51b81 100644 --- a/modelopt/torch/quantization/qtensor/fp8_tensor.py +++ b/modelopt/torch/quantization/qtensor/fp8_tensor.py @@ -149,3 +149,7 @@ def dequantize(self, dtype: torch.dtype = None, **kwarg): scales = scales.to(self._quantized_data.device) return (quantized_data.to(dtype) * scales.to(dtype))[slices] + +# Register FP8QTensor as safe global +from modelopt.torch.utils.serialization import 
add_modelopt_safe_globals +add_modelopt_safe_globals([FP8QTensor]) diff --git a/modelopt/torch/quantization/qtensor/int4_tensor.py b/modelopt/torch/quantization/qtensor/int4_tensor.py index f78d12aad7..68072ea8a7 100644 --- a/modelopt/torch/quantization/qtensor/int4_tensor.py +++ b/modelopt/torch/quantization/qtensor/int4_tensor.py @@ -128,3 +128,7 @@ def dequantize(self, dtype: torch.dtype = None, **kwarg): .reshape(self.metadata["shape"]) .to(dtype) ) + +# Register INT4QTensor as safe global +from modelopt.torch.utils.serialization import add_modelopt_safe_globals +add_modelopt_safe_globals([INT4QTensor]) diff --git a/modelopt/torch/quantization/qtensor/int8_tensor.py b/modelopt/torch/quantization/qtensor/int8_tensor.py index 890b88e9ae..1845be759c 100644 --- a/modelopt/torch/quantization/qtensor/int8_tensor.py +++ b/modelopt/torch/quantization/qtensor/int8_tensor.py @@ -122,3 +122,7 @@ def dequantize(self, dtype: torch.dtype = None, **kwarg): scales = scales.to(self._quantized_data.device) return (self._quantized_data.view(torch.int8).to(dtype) * scales.to(dtype))[slices] + +# Register INT8QTensor as safe global +from modelopt.torch.utils.serialization import add_modelopt_safe_globals +add_modelopt_safe_globals([INT8QTensor]) diff --git a/modelopt/torch/quantization/qtensor/mxfp4_tensor.py b/modelopt/torch/quantization/qtensor/mxfp4_tensor.py index 022825e406..424d5512fc 100644 --- a/modelopt/torch/quantization/qtensor/mxfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp4_tensor.py @@ -142,3 +142,7 @@ def unfuse_uint8_to_uint4(x): # Reshape back to the original shape return x_float.reshape(original_shape).to(dtype) + +# Register MXFP4QTensor as safe global +from modelopt.torch.utils.serialization import add_modelopt_safe_globals +add_modelopt_safe_globals([MXFP4QTensor]) diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py index 846a95ffcd..ef58692d52 100644 --- 
a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -260,3 +260,7 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: dequantized = dequantized[..., : original_shape[-1]] return dequantized.to(dtype) + +# Register MXFP8QTensor as safe global +from modelopt.torch.utils.serialization import add_modelopt_safe_globals +add_modelopt_safe_globals([MXFP8QTensor]) diff --git a/modelopt/torch/quantization/qtensor/nf4_tensor.py b/modelopt/torch/quantization/qtensor/nf4_tensor.py index 5b647d4dab..74e90afb92 100644 --- a/modelopt/torch/quantization/qtensor/nf4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nf4_tensor.py @@ -197,3 +197,7 @@ def dequantize(self, dtype: torch.dtype = None, **kwarg): .reshape(self.metadata["shape"]) .to(dtype) ) + +# Register NF4QTensor as safe global +from modelopt.torch.utils.serialization import add_modelopt_safe_globals +add_modelopt_safe_globals([NF4QTensor]) diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index 6ff31424c7..faaeee5ffc 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -374,3 +374,9 @@ def _unpack_tensor(input: torch.Tensor): (*tuple(deq_data.shape[:-1]), -1, block_size) ) * per_block_scale.unsqueeze(-1) return deq_data.reshape(self.metadata["shape"]).to(dtype) + + +# Register NVFP4QTensor as safe global +from modelopt.torch.utils.serialization import add_modelopt_safe_globals + +add_modelopt_safe_globals([NVFP4QTensor]) diff --git a/modelopt/torch/utils/__init__.py b/modelopt/torch/utils/__init__.py index f026e747a8..51d02248c1 100644 --- a/modelopt/torch/utils/__init__.py +++ b/modelopt/torch/utils/__init__.py @@ -26,5 +26,6 @@ from .perf import * from .regex import * from .robust_json import * +from .serialization import * from .tensor import * from .vlm_dataset_utils import * diff --git 
a/modelopt/torch/utils/serialization.py b/modelopt/torch/utils/serialization.py new file mode 100644 index 0000000000..cecd4a9bac --- /dev/null +++ b/modelopt/torch/utils/serialization.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Serialization utilities for secure checkpoint loading.""" + +import os +from io import BytesIO +from typing import Any, BinaryIO + +import torch + + +def safe_load(f: str | os.PathLike | BinaryIO | bytes, **kwargs) -> Any: + """Load a checkpoint securely using weights_only=True by default.""" + kwargs.setdefault("weights_only", True) + + if isinstance(f, (bytes, bytearray)): + f = BytesIO(f) + + return torch.load(f, **kwargs) # nosec B614 + + +def add_modelopt_safe_globals(classes: list[type]) -> None: + """Register ModelOpt classes as safe globals for torch serialization. + + This is required when weights_only=True is used in torch.load(). 
+ """ + if hasattr(torch.serialization, "add_safe_globals"): + torch.serialization.add_safe_globals(classes) + diff --git a/tests/unit/torch/utils/test_serialization.py b/tests/unit/torch/utils/test_serialization.py new file mode 100644 index 0000000000..a821d30a41 --- /dev/null +++ b/tests/unit/torch/utils/test_serialization.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Modelopt's serialization utilities.""" + +from io import BytesIO + +import torch + +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.utils.serialization import safe_load + + +class MockConfig(ModeloptBaseConfig): + """A mock configuration class for testing serialization.""" + + name: str = "mock" + + +def test_safe_load_with_modelopt_config(): + """Verify that safe_load can handle ModeloptBaseConfig subclasses with weights_only=True.""" + config = MockConfig(name="test_serialization") + state = {"config": config} + + buffer = BytesIO() + torch.save(state, buffer) + data = buffer.getvalue() + + # safe_load defaults to weights_only=True + loaded_state = safe_load(data) + + assert isinstance(loaded_state["config"], MockConfig) + assert loaded_state["config"].name == "test_serialization" + + +def test_safe_load_basic_types(): + """Verify that safe_load can handle basic types (standard torch.load functionality).""" + state = {"t": torch.ones(2), "v": [1, 2, 3], "d": {"a": 1}} + + buffer = BytesIO() + torch.save(state, buffer) + data = buffer.getvalue() + + loaded_state = safe_load(data) + + assert torch.allclose(loaded_state["t"], torch.ones(2)) + assert loaded_state["v"] == [1, 2, 3] + assert loaded_state["d"]["a"] == 1 + + +def test_safe_load_with_path(tmp_path): + """Verify that safe_load can handle file paths.""" + state = {"data": 42} + file_path = tmp_path / "test.pt" + + torch.save(state, file_path) + + loaded_state = safe_load(file_path) + + assert loaded_state["data"] == 42