Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ NVIDIA Model Optimizer Changelog
0.44 (2026-05-xx)
^^^^^^^^^^^^^^^^^

**Backward Breaking Changes**

- Changed the default of ``weights_only`` to ``True`` in ``torch.load`` for secure checkpoint loading. If you need to load a checkpoint that requires unpickling arbitrary objects, explicitly set ``weights_only=False``.

**New Features**

- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use a custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow, but it makes pruning slightly faster and may result in a slightly different pruned model because of different kernels and numerics.
Expand Down
3 changes: 2 additions & 1 deletion examples/deepseek/quantize_to_nvfp4.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from modelopt.torch.utils.serialization import safe_load
# Adapted from https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/fp8_cast_bf16.py
# MIT License

Expand Down Expand Up @@ -98,7 +99,7 @@ def remove_quantization_config_from_original_config(export_dir: str) -> None:

def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
state_dict_list = [
torch.load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt")
safe_load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt")
for rank in range(world_size)
]

Expand Down
3 changes: 2 additions & 1 deletion examples/llm_eval/modeling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from modelopt.torch.utils.serialization import safe_load
# Adapted from https://github.com/declare-lab/instruct-eval/blob/720e66f627369266ed1cfd74426666ec37e524bc/modeling.py

# MIT License
Expand Down Expand Up @@ -428,7 +429,7 @@ def load_quant(

model.load_state_dict(safe_load(checkpoint), strict=False)
else:
model.load_state_dict(torch.load(checkpoint), strict=False)
model.load_state_dict(safe_load(checkpoint), strict=False)

if eval:
quant.make_quant_attn(model)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from modelopt.torch.utils.serialization import safe_load
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
Expand Down Expand Up @@ -219,7 +220,7 @@ async def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id:
if not trtllm_dumped_file.exists():
return False
with open(trtllm_dumped_file, "rb") as f:
trtllm_dumped = torch.load(f)
trtllm_dumped = safe_load(f)
assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
"TRTLLM dumped should be a list with one element"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from modelopt.torch.utils.serialization import safe_load
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
Expand Down Expand Up @@ -60,7 +61,7 @@ def main(args: argparse.Namespace) -> None:
)

for i, file in enumerate(sampled_files):
data = torch.load(file)
data = safe_load(file)
expected_keys = [
"input_ids",
"hidden_states",
Expand Down
3 changes: 2 additions & 1 deletion examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ShardedDataset,
VisionLanguageDataCollator,
)
from modelopt.torch.utils.serialization import safe_load

try:
import wandb
Expand Down Expand Up @@ -71,7 +72,7 @@ def __len__(self):

def __getitem__(self, i) -> dict[str, torch.Tensor]:
try:
offline_data = torch.load(self.dumped_files[i])
offline_data = safe_load(self.dumped_files[i])
except Exception as e:
print(
f"[ERROR] Failed to load file at index={i}, "
Expand Down
3 changes: 2 additions & 1 deletion examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from modelopt.torch.utils.serialization import safe_load
# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py

# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
Expand Down Expand Up @@ -234,7 +235,7 @@ def train():
raise FileNotFoundError(
f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}"
)
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache)
model.eagle_module.d2t = safe_load(data_args.draft_vocab_cache)
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
else:
raise Exception(f"{training_args.mode} is not supported!")
Expand Down
3 changes: 2 additions & 1 deletion examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from modelopt.torch.utils.serialization import safe_load
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
Expand Down Expand Up @@ -284,7 +285,7 @@ def calibrate_loop(model: Any = None) -> None:
amax_file_path = quant_config["amax_file_path"]
if amax_file_path:
print(f"Loading amax values from {amax_file_path}")
saved_amax_dict = torch.load(amax_file_path)
saved_amax_dict = safe_load(amax_file_path)
# convert amax keys to vLLM format
if hasattr(self.model_runner.model, "hf_to_vllm_mapper"):
saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict)
Expand Down
10 changes: 5 additions & 5 deletions modelopt/torch/export/distribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch

from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils.serialization import safe_load

from .model_config_utils import (
model_config_from_dict,
Expand All @@ -41,7 +42,7 @@ class NFSWorkspace:
communication nor barrier. It is users' responsibility to synchronize
all ranks (local and remote processes).

This implementation uses `torch.save` and `torch.load` for serialization.
This implementation uses `torch.save` and `safe_load` for serialization.

Args:
workspace_path: the path to the NFS directory for postprocess cross rank communication.
Expand Down Expand Up @@ -91,8 +92,7 @@ def read_configs_and_weights_from_rank(
raise ValueError("NFSWorkspace is not initialized!")
state_path = self._get_state_path(target_rank)
if state_path.exists():
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
state = torch.load(state_path, map_location="cpu", weights_only=False)
state = safe_load(state_path, map_location="cpu")
return state["config"], state["weight"]
else:
return None, None
Expand Down Expand Up @@ -157,7 +157,7 @@ def get_tensors_parallel(tensor: torch.Tensor, ranks: list[int], group=None):
tensors.append(tensor)
else:
shm = SharedMemory(name=f"rank_{rank}", create=False)
shared_tensor = torch.load(BytesIO(shm.buf))
shared_tensor = safe_load(BytesIO(shm.buf))
tensors.append(shared_tensor)
shm_readers.append(shm)
try:
Expand Down Expand Up @@ -276,7 +276,7 @@ def _get_weights_nbytes(weights_dict: dict[str, torch.Tensor]):

if len_json != 0:
config_dict = json.loads(shm.buf[8 : 8 + len_json].tobytes().decode())
weights = torch.load(BytesIO(shm.buf[8 + len_json :]))
weights = safe_load(BytesIO(shm.buf[8 + len_json :]))
restore_model_config(config_dict, weights)
config = model_config_from_dict(config_dict)

Expand Down
28 changes: 27 additions & 1 deletion modelopt/torch/export/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,5 +631,31 @@ def num_kv_heads(self):

@property
def hidden_act(self):
    """Return the hidden activation function (e.g. ``"silu"``) of the model.

    All decoder layers share the same activation, so the value is read from
    the first layer's MLP config.
    """
    return self.layers[0].mlp.hidden_act


# Register every model-config class as a torch-serialization safe global so
# checkpoints containing these types can be loaded with weights_only=True.
from modelopt.torch.utils.serialization import add_modelopt_safe_globals

# NOTE: keep this list in sync with the config classes defined in this module.
_MODEL_CONFIG_CLASSES: list[type] = [
    EmbeddingConfig,
    LayernormConfig,
    LinearConfig,
    LinearActConfig,
    ConvConfig,
    QKVConfig,
    RelativeAttentionTableConfig,
    AttentionConfig,
    MLPConfig,
    ExpertConfig,
    RgLruConfig,
    RecurrentConfig,
    MOEConfig,
    DecoderLayerConfig,
    MedusaHeadConfig,
    ModelConfig,
]

add_modelopt_safe_globals(_MODEL_CONFIG_CLASSES)

8 changes: 8 additions & 0 deletions modelopt/torch/opt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView
from typing import Any, TypeAlias

import torch
from pydantic import (
BaseModel,
Field,
Expand Down Expand Up @@ -65,6 +66,13 @@ class ModeloptBaseConfig(BaseModel):

model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True)

def __init_subclass__(cls, **kwargs: Any) -> None:
    """Auto-register each config subclass as a torch-serialization safe global.

    This allows ``torch.load(..., weights_only=True)`` to deserialize
    checkpoints that embed instances of the subclass.
    """
    super().__init_subclass__(**kwargs)
    torch.serialization.add_safe_globals([cls])

def model_dump(self, **kwargs):
"""Dump the config to a dictionary with aliases and no warnings by default."""
kwargs = {"by_alias": True, "warnings": False, **kwargs}
Expand Down
15 changes: 7 additions & 8 deletions modelopt/torch/opt/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from modelopt import __version__
from modelopt.torch.utils import ModelLike, init_model_from_model_like, unwrap_model
from modelopt.torch.utils.serialization import safe_load

from .config import ConfigDict, ModeloptBaseConfig
from .mode import (
Expand Down Expand Up @@ -523,12 +524,8 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic
Returns:
A modelopt state dictionary describing the modifications to the model.
"""
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
kwargs.setdefault("weights_only", False)
kwargs.setdefault("map_location", "cpu")
# TODO: Add some validation to ensure the file is a valid modelopt state file.
modelopt_state = torch.load(modelopt_state_path, **kwargs)
return modelopt_state
return safe_load(modelopt_state_path, **kwargs)


def restore_from_modelopt_state(
Expand All @@ -550,7 +547,9 @@ def restore_from_modelopt_state(

# Restore the previously saved modelopt state followed by model weights
mto.restore_from_modelopt_state(model, modelopt_state_path="modelopt_state.pt")
model.load_state_dict(torch.load("model_weights.pt"), ...) # Load the model weights
model.load_state_dict(
torch.load("model_weights.pt", weights_only=True), ...
) # Load the model weights

If you want to restore the model weights and the modelopt state with saved scales, please use
:meth:`mto.restore()<modelopt.torch.opt.conversion.restore>`.
Expand Down Expand Up @@ -628,8 +627,8 @@ def restore(model: ModelLike, f: str | os.PathLike | BinaryIO, **kwargs) -> nn.M

# load checkpoint
kwargs.setdefault("map_location", "cpu")
kwargs.setdefault("weights_only", False)
objs = torch.load(f, **kwargs)
# Security NOTE: weights_only=True is used here on ModelOpt-generated checkpoints
objs = safe_load(f, **kwargs)

# restore model architecture
model_restored = restore_from_modelopt_state(model, objs["modelopt_state"])
Expand Down
23 changes: 19 additions & 4 deletions modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pathlib import Path
from typing import Any

import megatron.core as mcore
import torch
from megatron.core import dist_checkpointing, mpu
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
Expand All @@ -32,6 +33,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.utils.distributed as dist
from modelopt.torch.utils.network import SUPPORTED_WRAPPERS
from modelopt.torch.utils.serialization import safe_load

SUPPORTED_WRAPPERS[Float16Module] = "module"

Expand Down Expand Up @@ -131,6 +133,14 @@ def save_sharded_modelopt_state(
os.makedirs(modelopt_checkpoint_name, exist_ok=True)
modelopt_state = copy.deepcopy(mto.modelopt_state(model[0]))
remove_per_module_state(modelopt_state)

# Persist metadata for MCore version and sharded layout
modelopt_state["mcore_metadata"] = {
"version": mcore.__version__,
"singleton_local_shards": True, # Consistent with _MegatronMLP.sharded_state_dict()
"sharded_strategy": sharded_strategy,
}

dist_checkpointing.save(modelopt_state, modelopt_checkpoint_name, sharded_strategy)


Expand All @@ -156,7 +166,8 @@ def _load_extra_state_from_sharded_checkpoint(
is set to `True` (was not set before) in megatron-core-0.15.0. This flag affects the
sharded state_dict format and must be consistent between saving and loading.
"""
sharded_state_dict = model.sharded_state_dict(prefix=prefix)
# sharded_state_dict() should be called with metadata if provided to maintain consistency
sharded_state_dict = model.sharded_state_dict(prefix=prefix, metadata=metadata)
extra_sharded_state_dict = {k: v for k, v in sharded_state_dict.items() if "_extra_state" in k}
extra_state_dict = dist_checkpointing.load(
extra_sharded_state_dict,
Expand Down Expand Up @@ -203,11 +214,15 @@ def restore_sharded_modelopt_state(
return

# Loading the common modelopt_state (replicated on all ranks)
# Security NOTE: weights_only=False is used here on NVIDIA-generated file, not on untrusted user input
common_modelopt_state = torch.load(
modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, weights_only=False
common_modelopt_state = safe_load(
modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, map_location="cpu"
)

# Try to retrieve metadata from checkpoint
mcore_metadata = common_modelopt_state.get("mcore_metadata", None)
if metadata is None:
metadata = mcore_metadata

modelopt_load_version = common_modelopt_state["modelopt_version"]

print(f"nvidia-modelopt ckpt/inst version: {modelopt_load_version}/{modelopt.__version__}")
Expand Down
15 changes: 8 additions & 7 deletions modelopt/torch/opt/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
"""Support quantization and save/resore for Megatron."""

import contextlib
import pickle # nosec
import types
from io import BytesIO
from typing import Any

import megatron.core.transformer.mlp as megatron_mlp
import regex as re
import torch
from megatron.core.parallel_state import get_data_parallel_group

from modelopt.torch.utils.serialization import safe_load

from ..dynamic import DynamicModule


Expand Down Expand Up @@ -82,8 +84,10 @@ def _modelopt_get_extra_state(self):

# Serialize state into byte tensor
torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(extra_state)) # nosec
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
# Use torch.save for serialization to match safe_load
buffer = BytesIO()
torch.save(extra_state, buffer)
state_serialized = torch.frombuffer(buffer.getvalue(), dtype=torch.uint8)

return state_serialized

Expand All @@ -102,10 +106,7 @@ def _modelopt_set_extra_state(self, state: Any):
if state.numel() == 0:
return
# Default format: byte tensor with pickled data
#
# TODO: possible deserialization improvement
# https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serialization.py
extra_state = pickle.loads(state.detach().cpu().numpy().tobytes()) # nosec
extra_state = safe_load(state.detach().cpu().numpy().tobytes(), map_location="cpu")
else:
raise RuntimeError("Unsupported extra_state format.")

Expand Down
10 changes: 7 additions & 3 deletions modelopt/torch/opt/plugins/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from peft import PeftModel

from modelopt.torch.utils import get_unwrapped_name, print_rank_0
from modelopt.torch.utils.serialization import safe_load

from ..conversion import ModeloptStateManager, modelopt_state, restore_from_modelopt_state
from .huggingface import register_for_patching
Expand Down Expand Up @@ -83,11 +84,14 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs):
if os.path.isfile(_get_quantizer_state_save_path(model_id)):
from modelopt.torch.quantization.nn import TensorQuantizer

# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
quantizer_state_dict = torch.load(
_get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
# Security NOTE: weights_only=False is required here because quantizer_state_dict contains
# custom ModelOpt types like QTensorWrapper which are pickled. These files are expected
# to be generated by the PEFT plugin, not from untrusted sources.
quantizer_state_dict = safe_load(
_get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
)
for name, module in self.named_modules():

if isinstance(module, TensorQuantizer):
module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)])

Expand Down
Loading