Skip to content

[Security] Enable torch.load(weights_only=True) for secure checkpoint loading + trust_remote_code fix#1181

Merged
kevalmorabia97 merged 3 commits intomainfrom
kmorabia/secure-ckpt-loading
Apr 7, 2026
Merged

[Security] Enable torch.load(weights_only=True) for secure checkpoint loading + trust_remote_code fix#1181
kevalmorabia97 merged 3 commits intomainfrom
kmorabia/secure-ckpt-loading

Conversation

@kevalmorabia97
Copy link
Copy Markdown
Collaborator

@kevalmorabia97 kevalmorabia97 commented Apr 6, 2026

What does this PR do?

Testing

CICD tests ran

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

NVBug: 5999336

Summary by CodeRabbit

  • New Features

    • Added safe checkpoint save/load helpers and a --trust_remote_code CLI flag in examples to control remote-code loading.
  • Bug Fixes

    • Checkpoint loading now defaults to safer, weights-only semantics to reduce arbitrary-code exposure.
  • Documentation

    • CHANGELOG updated with security guidance and opt-in procedure for unsafe checkpoint loading.
  • Tests

    • New unit tests validating the safe-load behavior.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot bot commented Apr 6, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 6, 2026

📝 Walkthrough

Walkthrough

Introduces safe checkpoint helpers (safe_save/safe_load) and registers many ModelOpt classes as PyTorch safe globals; migrates internal checkpoint load/save calls to these helpers (defaulting to weights_only=True) and adds a configurable --trust_remote_code flag across numerous example scripts and loaders.

Changes

Cohort / File(s) Summary
Changelog & README
CHANGELOG.rst, examples/windows/onnx_ptq/whisper/README.md
Documented secure checkpoint-loading behavior and removed trust_remote_code=True from a README dataset example.
New serialization API
modelopt/torch/utils/serialization.py, modelopt/torch/utils/__init__.py
Added _sanitize_for_save, safe_save, safe_load; default safe_load to weights_only=True, register slice as safe global, and re-export serialization symbols.
Safe globals registration
modelopt/torch/export/model_config.py, modelopt/torch/opt/config.py, modelopt/torch/opt/hparam.py, modelopt/torch/quantization/qtensor/base_qtensor.py, modelopt/torch/quantization/calib/calibrator.py
Register config, hparam, quantizer, and calibrator subclasses with torch.serialization.add_safe_globals (module-level or via __init_subclass__).
Replace torch.load/save with safe_ helpers*
modelopt/torch/export/distribute.py, modelopt/torch/opt/conversion.py, modelopt/torch/opt/plugins/mcore_dist_checkpointing.py, modelopt/torch/opt/plugins/peft.py, modelopt/torch/opt/plugins/megatron.py, modelopt/torch/opt/searcher.py, modelopt/torch/prune/importance_hooks/*.py, examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py, tests/gpu/torch/export/test_vllm_fakequant_hf_export.py, tests/unit/torch/quantization/test_autoquant.py
Replaced direct torch.load(..., weights_only=False) / torch.save(...) calls with safe_load / safe_save, removed explicit weights_only=False defaults, and adapted in-memory serialization flows.
Example scripts: trust_remote_code flag
examples/gpt-oss/convert_oai_mxfp4_weight_only.py, examples/llm_autodeploy/run_auto_quantize.py, examples/llm_eval/..., examples/llm_ptq/..., examples/llm_qad/data_utils/download_dataset.py, examples/speculative_decoding/..., examples/speculative_decoding/scripts/..., examples/windows/accuracy_benchmark/...
Added --trust_remote_code CLI flags and threaded trust_remote_code into model/tokenizer/processor loading; replaced hardcoded trust_remote_code=True with configurable arguments.
Dataset loader adjustments
modelopt/torch/utils/speech_dataset_utils.py, examples/windows/onnx_ptq/whisper/whisper_onnx_quantization.py, examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py
Removed explicit trust_remote_code=True from datasets.load_dataset calls.
Utilities & examples wiring
examples/llm_ptq/vlm_utils.py, examples/llm_ptq/example_utils.py, examples/windows/torch_onnx/diffusers/...
Propagated trust_remote_code parameters through preview/generation helpers and updated function signatures to accept the flag.
Tests
tests/unit/torch/utils/test_serialization.py, other test updates
Added unit tests for safe_load covering bytes/path inputs and updated tests to use safe_load instead of torch.load(..., weights_only=False).

Sequence Diagram(s)

mermaid
sequenceDiagram
participant Caller
participant SafeUtils as "modelopt.torch.utils.serialization.safe_load"
participant File as "File / BytesIO"
participant TorchLoad as "torch.load"
Caller->>SafeUtils: safe_load(f, **kwargs)
SafeUtils->>File: wrap bytes/bytearray into BytesIO (if needed)
SafeUtils->>TorchLoad: torch.load(fileobj, weights_only=True, **kwargs)
TorchLoad-->>SafeUtils: deserialized object
SafeUtils-->>Caller: return object

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 63.64% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly describes the main security-focused changes: enabling safe checkpoint loading with weights_only=True and removing hard-coded trust_remote_code=True defaults.
Security Anti-Patterns ✅ Passed No security anti-patterns found. safe_load/safe_save with weights_only=True defaults implemented; trust_remote_code hardcoded=True eliminated; dataclasses registered with add_safe_globals; no eval/exec/nosec patterns present.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch kmorabia/secure-ckpt-loading

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 6, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-04-07 19:06 UTC

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/secure-ckpt-loading branch from da30b80 to eadd16c Compare April 6, 2026 15:00
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 6, 2026

Codecov Report

❌ Patch coverage is 90.14085% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.71%. Comparing base (80d2f02) to head (3081b2f).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
...h/prune/importance_hooks/compare_module_outputs.py 0.00% 4 Missing ⚠️
modelopt/torch/export/distribute.py 50.00% 1 Missing ⚠️
...odelopt/torch/prune/importance_hooks/base_hooks.py 50.00% 1 Missing ⚠️
modelopt/torch/utils/serialization.py 96.42% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1181      +/-   ##
==========================================
+ Coverage   74.77%   76.71%   +1.94%     
==========================================
  Files         351      352       +1     
  Lines       40289    40338      +49     
==========================================
+ Hits        30125    30947     +822     
+ Misses      10164     9391     -773     
Flag Coverage Δ
examples 45.26% <74.64%> (+5.02%) ⬆️
gpu 56.93% <83.09%> (-0.13%) ⬇️
unit 54.88% <78.87%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/secure-ckpt-loading branch from eadd16c to 7c662f6 Compare April 6, 2026 16:10
@kevalmorabia97 kevalmorabia97 changed the title [Security] Enable torch.load(weights_only=True) by default for secure checkpoint loading [Security] Enable torch.load(weights_only=True) for secure checkpoint loading + trust_remote_code fix Apr 6, 2026
@kevalmorabia97 kevalmorabia97 marked this pull request as ready for review April 6, 2026 16:20
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners April 6, 2026 16:20
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/secure-ckpt-loading branch from 63554d9 to 8b0c7c2 Compare April 6, 2026 16:30
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py (1)

87-122: ⚠️ Potential issue | 🟠 Major

Guard summary output when WER test is disabled.

With default --run_wer_test off, Line 118 accesses wer_result (and Lines 120-121 use references/predictions) before assignment, causing a runtime failure.

💡 Suggested fix
-    if args.run_wer_test:
+    if args.run_wer_test:
         librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")
@@
-        wer = load("wer")
-        wer_result = wer.compute(references=references, predictions=predictions)
-
-    print(
-        f"\n## DONE ## - wer = {wer_result}, wer% = {wer_result * 100}, accuracy% = {(1 - wer_result) * 100},"
-        f"\n  total-time = {time.time() - start_time} seconds,"
-        f"\n  num-distinct-inputs={len(set(references))},"
-        f"\n  len-reference={len(references)}, len-predictions={len(predictions)}\n\n"
-    )
+        wer = load("wer")
+        wer_result = wer.compute(references=references, predictions=predictions)
+        print(
+            f"\n## DONE ## - wer = {wer_result}, wer% = {wer_result * 100}, accuracy% = {(1 - wer_result) * 100},"
+            f"\n  total-time = {time.time() - start_time} seconds,"
+            f"\n  num-distinct-inputs={len(set(references))},"
+            f"\n  len-reference={len(references)}, len-predictions={len(predictions)}\n\n"
+        )
+    else:
+        print(f"\n## DONE ## - total-time = {time.time() - start_time} seconds\n")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py` around
lines 87 - 122, The final summary print uses wer_result, references and
predictions even when args.run_wer_test is false, causing an uninitialized
variable error; fix by guarding the computation and the print with the same
condition (args.run_wer_test) or initialize defaults: compute wer_result only
when load("wer") is run and only call the print that references
wer_result/references/predictions inside the if args.run_wer_test block (or set
safe defaults for wer_result, references, predictions before the if) so
model.generate/processor decode code paths and the final print do not reference
undefined symbols.
examples/llm_eval/modeling.py (1)

436-436: ⚠️ Potential issue | 🟡 Minor

Pre-existing security issue: torch.load without weights_only=True.

Line 436 uses torch.load(checkpoint) without specifying weights_only=True. While this is pre-existing code not modified by this PR, it should be addressed as part of the security hardening effort. As per coding guidelines: "Do not use torch.load(..., weights_only=False) unless a documented exception is provided."

🔒 Suggested fix
-        model.load_state_dict(torch.load(checkpoint), strict=False)
+        from modelopt.torch.utils import safe_load
+        model.load_state_dict(safe_load(checkpoint), strict=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_eval/modeling.py` at line 436, The call to torch.load when
loading the checkpoint into model.load_state_dict uses torch.load(checkpoint)
which can deserialize arbitrary objects; change it to pass weights_only=True
(i.e., call torch.load(checkpoint, weights_only=True)) so only tensor weights
are deserialized before calling model.load_state_dict(..., strict=False); update
the invocation near model.load_state_dict to use torch.load with
weights_only=True and, if necessary, add a short compatibility guard or comment
if running on older PyTorch versions that do not support weights_only.
🧹 Nitpick comments (4)
modelopt/torch/utils/speech_dataset_utils.py (1)

50-52: Security improvement: Removal of hardcoded trust_remote_code=True is correct.

This change properly addresses the security concern by removing the hardcoded trust_remote_code=True, allowing the default safe behavior (False).

For full alignment with the coding guideline pattern ("let the caller decide via a parameter; default to False"), consider exposing this as an optional parameter for callers who may need to load datasets that require remote code execution in the future.

♻️ Optional: Expose parameter to callers
-def _get_speech_dataset(dataset_name: str, num_samples: int):
+def _get_speech_dataset(dataset_name: str, num_samples: int, trust_remote_code: bool = False):
     """Load a portion of train dataset with the dataset name and a given size.

     Args:
         dataset_name: Name of the dataset to load.
         num_samples: Number of samples to load from the dataset.
+        trust_remote_code: Whether to trust remote code when loading the dataset.

     Returns:
         A hugging face Dataset.
     """
     # Load the dataset
     if dataset_name in SUPPORTED_SPEECH_DATASET_CONFIG:
         from datasets import load_dataset

         # Use streaming can reduce the downloading time for large datasets
         dataset = load_dataset(
-            **SUPPORTED_SPEECH_DATASET_CONFIG[dataset_name]["config"], streaming=True
+            **SUPPORTED_SPEECH_DATASET_CONFIG[dataset_name]["config"],
+            streaming=True,
+            trust_remote_code=trust_remote_code,
         )

And similarly propagate the parameter through get_speech_dataset_dataloader.

As per coding guidelines: "Do not hardcode trust_remote_code=True when loading Hugging Face Transformers models. Let the caller decide via a parameter; default to False."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/speech_dataset_utils.py` around lines 50 - 52, The
dataset loading call uses load_dataset without exposing trust_remote_code; add
an optional parameter (e.g., trust_remote_code: bool = False) to the function
that contains the load_dataset call and pass that parameter into load_dataset
(using SUPPORTED_SPEECH_DATASET_CONFIG[dataset_name]["config"] plus
trust_remote_code=trust_remote_code), and also propagate this new parameter
through the public helper get_speech_dataset_dataloader so callers can opt into
remote code execution while the default remains False.
tests/unit/torch/utils/test_serialization.py (1)

48-71: LGTM - Basic functionality tests.

Consider adding tests for safe_save and the _sanitize_for_save function to verify container subclass handling (e.g., defaultdictdict conversion).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/utils/test_serialization.py` around lines 48 - 71, Add unit
tests that exercise safe_save and the internal _sanitize_for_save to ensure
container subclasses are converted to plain containers (e.g.,
collections.defaultdict -> dict) before serialization; create a test that builds
a state containing a defaultdict and other subclassed containers, call safe_save
to a BytesIO or temp file and then safe_load (or inspect the serialized result)
to assert the loaded/serialized types are plain dict/list/tuple as appropriate;
reference the safe_save and _sanitize_for_save functions to locate the
serialization logic to cover this behavior.
examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py (1)

617-617: Consider using safe_save for consistency.

Line 617 uses torch.save directly while other parts of the codebase use safe_save. While torch.save isn't a security risk (loading is the concern), using safe_save would ensure the saved state is compatible with weights_only=True loading.

♻️ Optional consistency fix
-                torch.save(modelopt_state, str(modelopt_path))
+                from modelopt.torch.utils import safe_save
+                safe_save(modelopt_state, str(modelopt_path))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`
at line 617, Replace the direct torch.save call used to write modelopt_state to
modelopt_path with the project's safe_save utility to maintain consistency and
ensure the saved state is compatible with weights_only=True loading; locate the
occurrence where torch.save(modelopt_state, str(modelopt_path)) is called
(referencing the modelopt_state and modelopt_path symbols) and swap it to call
safe_save(modelopt_state, modelopt_path) or the project's equivalent safe_save
API, ensuring any required imports or path type adjustments are added.
modelopt/torch/utils/serialization.py (1)

28-48: Minor redundancy in list subclass handling.

Lines 43-45 have identical behavior in both branches - both return sanitized_list. The condition type(obj) is list check and the else branch produce the same result.

♻️ Simplify list handling
     if isinstance(obj, list):
         sanitized_list = [_sanitize_for_save(v) for v in obj]
-        if type(obj) is list:
-            return sanitized_list
         return sanitized_list
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/serialization.py` around lines 28 - 48, The
list-handling in _sanitize_for_save is redundant: it computes sanitized_list
then checks "if type(obj) is list" but both branches return sanitized_list;
remove the conditional and simply return sanitized_list unconditionally to
simplify the function (refer to the _sanitize_for_save function and the
sanitized_list variable and the existing type(obj) is list check to locate the
code to change).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/gpt-oss/convert_oai_mxfp4_weight_only.py`:
- Line 125: The kwargs dictionary passed to AutoModelForCausalLM.from_pretrained
uses an incorrect key "dtype" so the model's precision is ignored; update the
kwargs declaration (the variable named kwargs) to use "torch_dtype" instead of
"dtype" (preserving "device_map": "auto" and "trust_remote_code":
args.trust_remote_code) so AutoModelForCausalLM.from_pretrained receives the
proper parameter.

In `@examples/llm_autodeploy/run_auto_quantize.py`:
- Around line 207-211: The help text for the argparse flag "--trust_remote_code"
contains a typo ("trust_remotecode"); update the help string in the
parser.add_argument call that defines "--trust_remote_code" to read "Set
trust_remote_code for Huggingface models and tokenizers" so the option name and
help text match exactly.

In `@examples/llm_qad/data_utils/download_dataset.py`:
- Around line 162-166: Fix the typo in the argparse help string for the
"--trust_remote_code" option: update the help text in the p.add_argument call
(the argument name "--trust_remote_code") to read "Set trust_remote_code for
Huggingface models and tokenizers" instead of "trust_remotecode".

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py`:
- Line 91: The help string for the argument referring to remote code contains a
typo ("trust_remotecode"); update the help text to "trust_remote_code" for
consistency, i.e., locate the add_argument call that defines the
trust_remote_code flag (the help= parameter on that argument in
compute_hidden_states_hf.py) and change the help value to use
"trust_remote_code".

In `@examples/speculative_decoding/scripts/export_hf_checkpoint.py`:
- Line 32: The CLI flag parser currently uses
parser.add_argument("--trust_remote_code", type=bool, default=False, ...) which
misparses strings (bool("False") == True); change this to use a boolean flag
with action="store_true" (e.g., parser.add_argument("--trust_remote_code",
action="store_true", default=False, help=...)) so that passing
--trust_remote_code sets True and omitting it keeps False; update the
parser.add_argument call for the trust_remote_code option accordingly.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py`:
- Line 737: The call to safe_load in base_hooks.py (where activation_data =
safe_load(activation_file)) can load tensors onto their original device; change
the call to explicitly pass map_location="cpu" so tensors are deserialized onto
CPU for safe aggregation across ranks and CPU-only environments; update the
safe_load(activation_file) invocation to safe_load(activation_file,
map_location="cpu") (or the equivalent parameter name used by your deserializer)
so activation_data is guaranteed to be on CPU before further processing.

---

Outside diff comments:
In `@examples/llm_eval/modeling.py`:
- Line 436: The call to torch.load when loading the checkpoint into
model.load_state_dict uses torch.load(checkpoint) which can deserialize
arbitrary objects; change it to pass weights_only=True (i.e., call
torch.load(checkpoint, weights_only=True)) so only tensor weights are
deserialized before calling model.load_state_dict(..., strict=False); update the
invocation near model.load_state_dict to use torch.load with weights_only=True
and, if necessary, add a short compatibility guard or comment if running on
older PyTorch versions that do not support weights_only.

In `@examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py`:
- Around line 87-122: The final summary print uses wer_result, references and
predictions even when args.run_wer_test is false, causing an uninitialized
variable error; fix by guarding the computation and the print with the same
condition (args.run_wer_test) or initialize defaults: compute wer_result only
when load("wer") is run and only call the print that references
wer_result/references/predictions inside the if args.run_wer_test block (or set
safe defaults for wer_result, references, predictions before the if) so
model.generate/processor decode code paths and the final print do not reference
undefined symbols.

---

Nitpick comments:
In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`:
- Line 617: Replace the direct torch.save call used to write modelopt_state to
modelopt_path with the project's safe_save utility to maintain consistency and
ensure the saved state is compatible with weights_only=True loading; locate the
occurrence where torch.save(modelopt_state, str(modelopt_path)) is called
(referencing the modelopt_state and modelopt_path symbols) and swap it to call
safe_save(modelopt_state, modelopt_path) or the project's equivalent safe_save
API, ensuring any required imports or path type adjustments are added.

In `@modelopt/torch/utils/serialization.py`:
- Around line 28-48: The list-handling in _sanitize_for_save is redundant: it
computes sanitized_list then checks "if type(obj) is list" but both branches
return sanitized_list; remove the conditional and simply return sanitized_list
unconditionally to simplify the function (refer to the _sanitize_for_save
function and the sanitized_list variable and the existing type(obj) is list
check to locate the code to change).

In `@modelopt/torch/utils/speech_dataset_utils.py`:
- Around line 50-52: The dataset loading call uses load_dataset without exposing
trust_remote_code; add an optional parameter (e.g., trust_remote_code: bool =
False) to the function that contains the load_dataset call and pass that
parameter into load_dataset (using
SUPPORTED_SPEECH_DATASET_CONFIG[dataset_name]["config"] plus
trust_remote_code=trust_remote_code), and also propagate this new parameter
through the public helper get_speech_dataset_dataloader so callers can opt into
remote code execution while the default remains False.

In `@tests/unit/torch/utils/test_serialization.py`:
- Around line 48-71: Add unit tests that exercise safe_save and the internal
_sanitize_for_save to ensure container subclasses are converted to plain
containers (e.g., collections.defaultdict -> dict) before serialization; create
a test that builds a state containing a defaultdict and other subclassed
containers, call safe_save to a BytesIO or temp file and then safe_load (or
inspect the serialized result) to assert the loaded/serialized types are plain
dict/list/tuple as appropriate; reference the safe_save and _sanitize_for_save
functions to locate the serialization logic to cover this behavior.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 972acaa8-498b-4f10-ac29-85d073133431

📥 Commits

Reviewing files that changed from the base of the PR and between 4a5ef01 and 63554d9.

📒 Files selected for processing (38)
  • CHANGELOG.rst
  • examples/gpt-oss/convert_oai_mxfp4_weight_only.py
  • examples/llm_autodeploy/run_auto_quantize.py
  • examples/llm_eval/lm_eval_hf.py
  • examples/llm_eval/modeling.py
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • examples/llm_ptq/vlm_utils.py
  • examples/llm_qad/data_utils/download_dataset.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • examples/speculative_decoding/scripts/send_conversation_vllm.py
  • examples/windows/accuracy_benchmark/mmlu_benchmark.py
  • examples/windows/accuracy_benchmark/modeling.py
  • examples/windows/onnx_ptq/whisper/README.md
  • examples/windows/onnx_ptq/whisper/whisper_onnx_quantization.py
  • examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py
  • examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
  • modelopt/torch/export/distribute.py
  • modelopt/torch/export/model_config.py
  • modelopt/torch/opt/config.py
  • modelopt/torch/opt/conversion.py
  • modelopt/torch/opt/hparam.py
  • modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
  • modelopt/torch/opt/plugins/megatron.py
  • modelopt/torch/opt/plugins/peft.py
  • modelopt/torch/opt/searcher.py
  • modelopt/torch/prune/importance_hooks/base_hooks.py
  • modelopt/torch/prune/importance_hooks/compare_module_outputs.py
  • modelopt/torch/quantization/qtensor/base_qtensor.py
  • modelopt/torch/utils/__init__.py
  • modelopt/torch/utils/serialization.py
  • modelopt/torch/utils/speech_dataset_utils.py
  • tests/gpu/torch/export/test_vllm_fakequant_hf_export.py
  • tests/gpu/torch/quantization/test_gptq.py
  • tests/unit/torch/quantization/test_autoquant.py
  • tests/unit/torch/utils/test_serialization.py

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/secure-ckpt-loading branch from 8b0c7c2 to a517e9e Compare April 6, 2026 16:35
Copy link
Copy Markdown
Collaborator

@shengliangxu shengliangxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall.

Copy link
Copy Markdown
Contributor

@meenchen meenchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Copy Markdown
Contributor

@h-guo18 h-guo18 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

examples/speculative_decoding changes LGTM

@vishalpandya1990
Copy link
Copy Markdown
Contributor

Had a quick look at examples/windows changes - looks okay to me.

kevalmorabia97 and others added 3 commits April 6, 2026 22:27
…int loading

Co-authored-by: RinZ27 <222222878+RinZ27@users.noreply.github.com>
Signed-off-by: RinZ27 <222222878+RinZ27@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/secure-ckpt-loading branch from 1a3ae07 to 3081b2f Compare April 7, 2026 05:27
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
modelopt/torch/export/distribute.py (1)

78-78: Consider using safe_save for consistency.

Line 78 uses torch.save while line 95 uses safe_load. For consistency with the PR's secure checkpoint handling, consider updating this to use safe_save:

-        torch.save({"config": config_json, "weight": weights}, self.state_path)
+        safe_save({"config": config_json, "weight": weights}, self.state_path)

This would require adding safe_save to the import at line 28.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/distribute.py` at line 78, Replace the direct call to
torch.save in the checkpoint export with the project's safe_save utility to keep
secure/consistent checkpoint handling: change the call that currently saves
{"config": config_json, "weight": weights} to use safe_save(self.state_path,
{"config": config_json, "weight": weights}) and add safe_save to the imports
alongside the existing imports so the function is available; ensure you
reference the same state_path and payload used by torch.save and keep semantics
identical to the following safe_save usage elsewhere (matching safe_load usage).
modelopt/torch/opt/conversion.py (1)

513-513: Consider using safe_save in save() for consistency.

The save() function uses torch.save while restore() and load_modelopt_state() now use safe_load. For consistency and to ensure the saved format is compatible with weights_only=True loading, consider updating to use safe_save:

-    torch.save(ckpt_dict, f, **kwargs)
+    from modelopt.torch.utils import safe_save
+    safe_save(ckpt_dict, f, **kwargs)

Or add safe_save to the import at line 37.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/opt/conversion.py` at line 513, The save() function currently
calls torch.save(ckpt_dict, f, **kwargs); change it to use safe_save(ckpt_dict,
f, **kwargs) so saved checkpoints are compatible with the existing
safe_load-based restore() and load_modelopt_state() behavior and
weights_only=True loading; also add safe_save to the module imports (where other
torch helpers are imported) so the symbol is available. Ensure you only replace
torch.save with safe_save in the save() implementation and update imports to
include safe_save.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/export/distribute.py`:
- Line 78: Replace the direct call to torch.save in the checkpoint export with
the project's safe_save utility to keep secure/consistent checkpoint handling:
change the call that currently saves {"config": config_json, "weight": weights}
to use safe_save(self.state_path, {"config": config_json, "weight": weights})
and add safe_save to the imports alongside the existing imports so the function
is available; ensure you reference the same state_path and payload used by
torch.save and keep semantics identical to the following safe_save usage
elsewhere (matching safe_load usage).

In `@modelopt/torch/opt/conversion.py`:
- Line 513: The save() function currently calls torch.save(ckpt_dict, f,
**kwargs); change it to use safe_save(ckpt_dict, f, **kwargs) so saved
checkpoints are compatible with the existing safe_load-based restore() and
load_modelopt_state() behavior and weights_only=True loading; also add safe_save
to the module imports (where other torch helpers are imported) so the symbol is
available. Ensure you only replace torch.save with safe_save in the save()
implementation and update imports to include safe_save.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 17a54433-cf2b-4f8d-80ef-e391e158845f

📥 Commits

Reviewing files that changed from the base of the PR and between 1a3ae07 and 3081b2f.

📒 Files selected for processing (39)
  • CHANGELOG.rst
  • examples/gpt-oss/convert_oai_mxfp4_weight_only.py
  • examples/llm_autodeploy/run_auto_quantize.py
  • examples/llm_eval/lm_eval_hf.py
  • examples/llm_eval/modeling.py
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • examples/llm_ptq/vlm_utils.py
  • examples/llm_qad/data_utils/download_dataset.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • examples/speculative_decoding/scripts/send_conversation_vllm.py
  • examples/windows/accuracy_benchmark/mmlu_benchmark.py
  • examples/windows/accuracy_benchmark/modeling.py
  • examples/windows/onnx_ptq/whisper/README.md
  • examples/windows/onnx_ptq/whisper/whisper_onnx_quantization.py
  • examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py
  • examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
  • modelopt/torch/export/distribute.py
  • modelopt/torch/export/model_config.py
  • modelopt/torch/opt/config.py
  • modelopt/torch/opt/conversion.py
  • modelopt/torch/opt/hparam.py
  • modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
  • modelopt/torch/opt/plugins/megatron.py
  • modelopt/torch/opt/plugins/peft.py
  • modelopt/torch/opt/searcher.py
  • modelopt/torch/prune/importance_hooks/base_hooks.py
  • modelopt/torch/prune/importance_hooks/compare_module_outputs.py
  • modelopt/torch/quantization/calib/calibrator.py
  • modelopt/torch/quantization/qtensor/base_qtensor.py
  • modelopt/torch/utils/__init__.py
  • modelopt/torch/utils/serialization.py
  • modelopt/torch/utils/speech_dataset_utils.py
  • tests/gpu/torch/export/test_vllm_fakequant_hf_export.py
  • tests/gpu/torch/quantization/test_gptq.py
  • tests/unit/torch/quantization/test_autoquant.py
  • tests/unit/torch/utils/test_serialization.py
✅ Files skipped from review due to trivial changes (5)
  • examples/windows/onnx_ptq/whisper/README.md
  • tests/gpu/torch/quantization/test_gptq.py
  • CHANGELOG.rst
  • tests/unit/torch/utils/test_serialization.py
  • examples/llm_ptq/example_utils.py
🚧 Files skipped from review as they are similar to previous changes (24)
  • examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py
  • examples/llm_eval/lm_eval_hf.py
  • examples/speculative_decoding/scripts/send_conversation_vllm.py
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/utils/init.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
  • modelopt/torch/utils/speech_dataset_utils.py
  • examples/windows/accuracy_benchmark/mmlu_benchmark.py
  • tests/unit/torch/quantization/test_autoquant.py
  • modelopt/torch/opt/plugins/megatron.py
  • modelopt/torch/opt/plugins/peft.py
  • modelopt/torch/quantization/qtensor/base_qtensor.py
  • modelopt/torch/export/model_config.py
  • modelopt/torch/prune/importance_hooks/base_hooks.py
  • examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
  • modelopt/torch/quantization/calib/calibrator.py
  • examples/windows/onnx_ptq/whisper/whisper_onnx_quantization.py
  • examples/llm_qad/data_utils/download_dataset.py
  • modelopt/torch/utils/serialization.py
  • examples/llm_eval/modeling.py
  • examples/llm_ptq/vlm_utils.py
  • modelopt/torch/opt/plugins/mcore_dist_checkpointing.py

@kevalmorabia97 kevalmorabia97 requested a review from mxinO April 7, 2026 16:25
@kevalmorabia97 kevalmorabia97 merged commit 5dc17df into main Apr 7, 2026
58 of 61 checks passed
@kevalmorabia97 kevalmorabia97 deleted the kmorabia/secure-ckpt-loading branch April 7, 2026 19:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants