[Security] Enable torch.load(weights_only=True) for secure checkpoint loading + trust_remote_code fix#1181
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

Introduces safe checkpoint helpers (safe_save/safe_load), migrates checkpoint loading to weights_only=True, and removes hardcoded trust_remote_code=True.

Changes
Sequence Diagram(s): mermaid

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Force-pushed from da30b80 to eadd16c
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

@@ Coverage Diff @@
## main #1181 +/- ##
==========================================
+ Coverage 74.77% 76.71% +1.94%
==========================================
Files 351 352 +1
Lines 40289 40338 +49
==========================================
+ Hits 30125 30947 +822
+ Misses 10164 9391 -773
Flags with carried forward coverage won't be shown.
Force-pushed from eadd16c to 7c662f6
Force-pushed from 63554d9 to 8b0c7c2
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py (1)
87-122: ⚠️ Potential issue | 🟠 Major — Guard summary output when WER test is disabled.
With default --run_wer_test off, Line 118 accesses wer_result (and Lines 120-121 use references/predictions) before assignment, causing a runtime failure.

💡 Suggested fix
```diff
-    if args.run_wer_test:
+    if args.run_wer_test:
         librispeech_test_clean = load_dataset("librispeech_asr", "clean", split="test")
@@
-    wer = load("wer")
-    wer_result = wer.compute(references=references, predictions=predictions)
-
-    print(
-        f"\n## DONE ## - wer = {wer_result}, wer% = {wer_result * 100}, accuracy% = {(1 - wer_result) * 100},"
-        f"\n total-time = {time.time() - start_time} seconds,"
-        f"\n num-distinct-inputs={len(set(references))},"
-        f"\n len-reference={len(references)}, len-predictions={len(predictions)}\n\n"
-    )
+        wer = load("wer")
+        wer_result = wer.compute(references=references, predictions=predictions)
+        print(
+            f"\n## DONE ## - wer = {wer_result}, wer% = {wer_result * 100}, accuracy% = {(1 - wer_result) * 100},"
+            f"\n total-time = {time.time() - start_time} seconds,"
+            f"\n num-distinct-inputs={len(set(references))},"
+            f"\n len-reference={len(references)}, len-predictions={len(predictions)}\n\n"
+        )
+    else:
+        print(f"\n## DONE ## - total-time = {time.time() - start_time} seconds\n")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py` around lines 87 - 122, the final summary print uses wer_result, references and predictions even when args.run_wer_test is false, causing an uninitialized variable error; fix by guarding the computation and the print with the same condition (args.run_wer_test) or initialize defaults: compute wer_result only when load("wer") is run and only call the print that references wer_result/references/predictions inside the if args.run_wer_test block (or set safe defaults for wer_result, references, predictions before the if) so model.generate/processor decode code paths and the final print do not reference undefined symbols.

examples/llm_eval/modeling.py (1)
436-436: ⚠️ Potential issue | 🟡 Minor — Pre-existing security issue: torch.load without weights_only=True.

Line 436 uses torch.load(checkpoint) without specifying weights_only=True. While this is pre-existing code not modified by this PR, it should be addressed as part of the security hardening effort. As per coding guidelines: "Do not use torch.load(..., weights_only=False) unless a documented exception is provided."

🔒 Suggested fix
```diff
-    model.load_state_dict(torch.load(checkpoint), strict=False)
+    from modelopt.torch.utils import safe_load
+    model.load_state_dict(safe_load(checkpoint), strict=False)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_eval/modeling.py` at line 436, The call to torch.load when loading the checkpoint into model.load_state_dict uses torch.load(checkpoint) which can deserialize arbitrary objects; change it to pass weights_only=True (i.e., call torch.load(checkpoint, weights_only=True)) so only tensor weights are deserialized before calling model.load_state_dict(..., strict=False); update the invocation near model.load_state_dict to use torch.load with weights_only=True and, if necessary, add a short compatibility guard or comment if running on older PyTorch versions that do not support weights_only.
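The compatibility guard mentioned in the prompt can be sketched as follows (load_checkpoint is a hypothetical helper name used here for illustration, not a modelopt API):

```python
import inspect
import io

import torch


def load_checkpoint(path_or_file):
    """Load a checkpoint, preferring weights_only=True when available (sketch)."""
    if "weights_only" in inspect.signature(torch.load).parameters:
        # Restricted unpickler: only tensors and allow-listed types are rebuilt.
        return torch.load(path_or_file, map_location="cpu", weights_only=True)
    # Older PyTorch without weights_only support; this branch would need a
    # documented exception per the coding guidelines.
    return torch.load(path_or_file, map_location="cpu")


# Round-trip a plain tensor state dict in memory.
buf = io.BytesIO()
torch.save({"w": torch.zeros(2, 3)}, buf)
buf.seek(0)
state = load_checkpoint(buf)
```

model.load_state_dict(load_checkpoint(checkpoint), strict=False) would then match the suggested fix without breaking older PyTorch versions.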
🧹 Nitpick comments (4)
modelopt/torch/utils/speech_dataset_utils.py (1)
50-52: Security improvement: Removal of hardcoded trust_remote_code=True is correct.

This change properly addresses the security concern by removing the hardcoded trust_remote_code=True, allowing the default safe behavior (False). For full alignment with the coding guideline pattern ("let the caller decide via a parameter; default to False"), consider exposing this as an optional parameter for callers who may need to load datasets that require remote code execution in the future.

♻️ Optional: Expose parameter to callers

```diff
-def _get_speech_dataset(dataset_name: str, num_samples: int):
+def _get_speech_dataset(dataset_name: str, num_samples: int, trust_remote_code: bool = False):
     """Load a portion of train dataset with the dataset name and a given size.

     Args:
         dataset_name: Name of the dataset to load.
         num_samples: Number of samples to load from the dataset.
+        trust_remote_code: Whether to trust remote code when loading the dataset.

     Returns:
         A hugging face Dataset.
     """
     # Load the dataset
     if dataset_name in SUPPORTED_SPEECH_DATASET_CONFIG:
         from datasets import load_dataset

         # Use streaming can reduce the downloading time for large datasets
         dataset = load_dataset(
-            **SUPPORTED_SPEECH_DATASET_CONFIG[dataset_name]["config"], streaming=True
+            **SUPPORTED_SPEECH_DATASET_CONFIG[dataset_name]["config"],
+            streaming=True,
+            trust_remote_code=trust_remote_code,
         )
```

And similarly propagate the parameter through get_speech_dataset_dataloader. As per coding guidelines: "Do not hardcode trust_remote_code=True when loading Hugging Face Transformers models. Let the caller decide via a parameter; default to False."

🤖 Prompt for AI Agents
48-71: LGTM - Basic functionality tests.

Consider adding tests for safe_save and the _sanitize_for_save function to verify container subclass handling (e.g., defaultdict → dict conversion).
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/utils/test_serialization.py` around lines 48 - 71, add unit tests that exercise safe_save and the internal _sanitize_for_save to ensure container subclasses are converted to plain containers (e.g., collections.defaultdict -> dict) before serialization; create a test that builds a state containing a defaultdict and other subclassed containers, call safe_save to a BytesIO or temp file and then safe_load (or inspect the serialized result) to assert the loaded/serialized types are plain dict/list/tuple as appropriate; reference the safe_save and _sanitize_for_save functions to locate the serialization logic to cover this behavior.

examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py (1)
617-617: Consider using safe_save for consistency.

Line 617 uses torch.save directly while other parts of the codebase use safe_save. While torch.save isn't a security risk (loading is the concern), using safe_save would ensure the saved state is compatible with weights_only=True loading.

♻️ Optional consistency fix

```diff
-    torch.save(modelopt_state, str(modelopt_path))
+    from modelopt.torch.utils import safe_save
+    safe_save(modelopt_state, str(modelopt_path))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py` at line 617, replace the direct torch.save call used to write modelopt_state to modelopt_path with the project's safe_save utility to maintain consistency and ensure the saved state is compatible with weights_only=True loading; locate the occurrence where torch.save(modelopt_state, str(modelopt_path)) is called (referencing the modelopt_state and modelopt_path symbols) and swap it to call safe_save(modelopt_state, modelopt_path) or the project's equivalent safe_save API, ensuring any required imports or path type adjustments are added.

modelopt/torch/utils/serialization.py (1)
28-48: Minor redundancy in list subclass handling.

Lines 43-45 have identical behavior in both branches - both return sanitized_list. The type(obj) is list check and the else branch produce the same result.

♻️ Simplify list handling

```diff
     if isinstance(obj, list):
         sanitized_list = [_sanitize_for_save(v) for v in obj]
-        if type(obj) is list:
-            return sanitized_list
         return sanitized_list
```
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/serialization.py` around lines 28 - 48, The list-handling in _sanitize_for_save is redundant: it computes sanitized_list then checks "if type(obj) is list" but both branches return sanitized_list; remove the conditional and simply return sanitized_list unconditionally to simplify the function (refer to the _sanitize_for_save function and the sanitized_list variable and the existing type(obj) is list check to locate the code to change).
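After dropping the redundant branch, the sanitizer this comment describes reduces to a plain recursion; a minimal stand-in sketch (modelopt's actual _sanitize_for_save may cover more types, e.g. tensors or sets):

```python
from collections import OrderedDict, defaultdict


def _sanitize_for_save(obj):
    """Recursively convert container subclasses to plain dict/list/tuple (sketch)."""
    if isinstance(obj, dict):
        # Catches dict subclasses such as defaultdict and OrderedDict.
        return {k: _sanitize_for_save(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_sanitize_for_save(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_sanitize_for_save(v) for v in obj)
    return obj


state = defaultdict(list, {"a": [1, 2], "b": OrderedDict(c=(3,))})
sanitized = _sanitize_for_save(state)
```

The plain containers that come out are exactly the types the weights_only=True unpickler accepts without any allow-listing.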
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gpt-oss/convert_oai_mxfp4_weight_only.py`:
- Line 125: The kwargs dictionary passed to AutoModelForCausalLM.from_pretrained
uses an incorrect key "dtype" so the model's precision is ignored; update the
kwargs declaration (the variable named kwargs) to use "torch_dtype" instead of
"dtype" (preserving "device_map": "auto" and "trust_remote_code":
args.trust_remote_code) so AutoModelForCausalLM.from_pretrained receives the
proper parameter.
In `@examples/llm_autodeploy/run_auto_quantize.py`:
- Around line 207-211: The help text for the argparse flag "--trust_remote_code"
contains a typo ("trust_remotecode"); update the help string in the
parser.add_argument call that defines "--trust_remote_code" to read "Set
trust_remote_code for Huggingface models and tokenizers" so the option name and
help text match exactly.
In `@examples/llm_qad/data_utils/download_dataset.py`:
- Around line 162-166: Fix the typo in the argparse help string for the
"--trust_remote_code" option: update the help text in the p.add_argument call
(the argument name "--trust_remote_code") to read "Set trust_remote_code for
Huggingface models and tokenizers" instead of "trust_remotecode".
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py`:
- Line 91: The help string for the argument referring to remote code contains a
typo ("trust_remotecode"); update the help text to "trust_remote_code" for
consistency, i.e., locate the add_argument call that defines the
trust_remote_code flag (the help= parameter on that argument in
compute_hidden_states_hf.py) and change the help value to use
"trust_remote_code".
In `@examples/speculative_decoding/scripts/export_hf_checkpoint.py`:
- Line 32: The CLI flag parser currently uses
parser.add_argument("--trust_remote_code", type=bool, default=False, ...) which
misparses strings (bool("False") == True); change this to use a boolean flag
with action="store_true" (e.g., parser.add_argument("--trust_remote_code",
action="store_true", default=False, help=...)) so that passing
--trust_remote_code sets True and omitting it keeps False; update the
parser.add_argument call for the trust_remote_code option accordingly.
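The misparse is easy to demonstrate; a short sketch contrasting type=bool with action="store_true":

```python
import argparse

# type=bool runs bool() on the raw string, and any non-empty string is truthy,
# so "--trust_remote_code False" still yields True.
bad = argparse.ArgumentParser()
bad.add_argument("--trust_remote_code", type=bool, default=False)
assert bad.parse_args(["--trust_remote_code", "False"]).trust_remote_code is True

# action="store_true" gives flag semantics: present -> True, absent -> False.
good = argparse.ArgumentParser()
good.add_argument("--trust_remote_code", action="store_true", default=False)
assert good.parse_args([]).trust_remote_code is False
assert good.parse_args(["--trust_remote_code"]).trust_remote_code is True
```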
In `@modelopt/torch/prune/importance_hooks/base_hooks.py`:
- Line 737: The call to safe_load in base_hooks.py (where activation_data =
safe_load(activation_file)) can load tensors onto their original device; change
the call to explicitly pass map_location="cpu" so tensors are deserialized onto
CPU for safe aggregation across ranks and CPU-only environments; update the
safe_load(activation_file) invocation to safe_load(activation_file,
map_location="cpu") (or the equivalent parameter name used by your deserializer)
so activation_data is guaranteed to be on CPU before further processing.
---
Outside diff comments:
In `@examples/llm_eval/modeling.py`:
- Line 436: The call to torch.load when loading the checkpoint into
model.load_state_dict uses torch.load(checkpoint) which can deserialize
arbitrary objects; change it to pass weights_only=True (i.e., call
torch.load(checkpoint, weights_only=True)) so only tensor weights are
deserialized before calling model.load_state_dict(..., strict=False); update the
invocation near model.load_state_dict to use torch.load with weights_only=True
and, if necessary, add a short compatibility guard or comment if running on
older PyTorch versions that do not support weights_only.
In `@examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py`:
- Around line 87-122: The final summary print uses wer_result, references and
predictions even when args.run_wer_test is false, causing an uninitialized
variable error; fix by guarding the computation and the print with the same
condition (args.run_wer_test) or initialize defaults: compute wer_result only
when load("wer") is run and only call the print that references
wer_result/references/predictions inside the if args.run_wer_test block (or set
safe defaults for wer_result, references, predictions before the if) so
model.generate/processor decode code paths and the final print do not reference
undefined symbols.
---
Nitpick comments:
In
`@examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`:
- Line 617: Replace the direct torch.save call used to write modelopt_state to
modelopt_path with the project's safe_save utility to maintain consistency and
ensure the saved state is compatible with weights_only=True loading; locate the
occurrence where torch.save(modelopt_state, str(modelopt_path)) is called
(referencing the modelopt_state and modelopt_path symbols) and swap it to call
safe_save(modelopt_state, modelopt_path) or the project's equivalent safe_save
API, ensuring any required imports or path type adjustments are added.
In `@modelopt/torch/utils/serialization.py`:
- Around line 28-48: The list-handling in _sanitize_for_save is redundant: it
computes sanitized_list then checks "if type(obj) is list" but both branches
return sanitized_list; remove the conditional and simply return sanitized_list
unconditionally to simplify the function (refer to the _sanitize_for_save
function and the sanitized_list variable and the existing type(obj) is list
check to locate the code to change).
In `@modelopt/torch/utils/speech_dataset_utils.py`:
- Around line 50-52: The dataset loading call uses load_dataset without exposing
trust_remote_code; add an optional parameter (e.g., trust_remote_code: bool =
False) to the function that contains the load_dataset call and pass that
parameter into load_dataset (using
SUPPORTED_SPEECH_DATASET_CONFIG[dataset_name]["config"] plus
trust_remote_code=trust_remote_code), and also propagate this new parameter
through the public helper get_speech_dataset_dataloader so callers can opt into
remote code execution while the default remains False.
In `@tests/unit/torch/utils/test_serialization.py`:
- Around line 48-71: Add unit tests that exercise safe_save and the internal
_sanitize_for_save to ensure container subclasses are converted to plain
containers (e.g., collections.defaultdict -> dict) before serialization; create
a test that builds a state containing a defaultdict and other subclassed
containers, call safe_save to a BytesIO or temp file and then safe_load (or
inspect the serialized result) to assert the loaded/serialized types are plain
dict/list/tuple as appropriate; reference the safe_save and _sanitize_for_save
functions to locate the serialization logic to cover this behavior.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 972acaa8-498b-4f10-ac29-85d073133431
📒 Files selected for processing (38)
- CHANGELOG.rst
- examples/gpt-oss/convert_oai_mxfp4_weight_only.py
- examples/llm_autodeploy/run_auto_quantize.py
- examples/llm_eval/lm_eval_hf.py
- examples/llm_eval/modeling.py
- examples/llm_ptq/example_utils.py
- examples/llm_ptq/hf_ptq.py
- examples/llm_ptq/vlm_utils.py
- examples/llm_qad/data_utils/download_dataset.py
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
- examples/speculative_decoding/scripts/ar_validate.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- examples/speculative_decoding/scripts/send_conversation_vllm.py
- examples/windows/accuracy_benchmark/mmlu_benchmark.py
- examples/windows/accuracy_benchmark/modeling.py
- examples/windows/onnx_ptq/whisper/README.md
- examples/windows/onnx_ptq/whisper/whisper_onnx_quantization.py
- examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py
- examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
- modelopt/torch/export/distribute.py
- modelopt/torch/export/model_config.py
- modelopt/torch/opt/config.py
- modelopt/torch/opt/conversion.py
- modelopt/torch/opt/hparam.py
- modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
- modelopt/torch/opt/plugins/megatron.py
- modelopt/torch/opt/plugins/peft.py
- modelopt/torch/opt/searcher.py
- modelopt/torch/prune/importance_hooks/base_hooks.py
- modelopt/torch/prune/importance_hooks/compare_module_outputs.py
- modelopt/torch/quantization/qtensor/base_qtensor.py
- modelopt/torch/utils/__init__.py
- modelopt/torch/utils/serialization.py
- modelopt/torch/utils/speech_dataset_utils.py
- tests/gpu/torch/export/test_vllm_fakequant_hf_export.py
- tests/gpu/torch/quantization/test_gptq.py
- tests/unit/torch/quantization/test_autoquant.py
- tests/unit/torch/utils/test_serialization.py
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (outdated comment, resolved)
8b0c7c2 to
a517e9e
Compare
h-guo18 left a comment:

examples/speculative_decoding changes LGTM
Had a quick look at examples/windows changes - looks okay to me.
…int loading

Co-authored-by: RinZ27 <222222878+RinZ27@users.noreply.github.com>
Signed-off-by: RinZ27 <222222878+RinZ27@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Force-pushed from 1a3ae07 to 3081b2f
🧹 Nitpick comments (2)
modelopt/torch/export/distribute.py (1)
78-78: Consider using safe_save for consistency.

Line 78 uses torch.save while line 95 uses safe_load. For consistency with the PR's secure checkpoint handling, consider updating this to use safe_save:

```diff
-    torch.save({"config": config_json, "weight": weights}, self.state_path)
+    safe_save({"config": config_json, "weight": weights}, self.state_path)
```

This would require adding safe_save to the import at line 28.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/distribute.py` at line 78, replace the direct call to torch.save in the checkpoint export with the project's safe_save utility to keep secure/consistent checkpoint handling: change the call that currently saves {"config": config_json, "weight": weights} to use safe_save(self.state_path, {"config": config_json, "weight": weights}) and add safe_save to the imports alongside the existing imports so the function is available; ensure you reference the same state_path and payload used by torch.save and keep semantics identical to the following safe_save usage elsewhere (matching safe_load usage).

modelopt/torch/opt/conversion.py (1)
513-513: Consider using safe_save in save() for consistency.

The save() function uses torch.save while restore() and load_modelopt_state() now use safe_load. For consistency and to ensure the saved format is compatible with weights_only=True loading, consider updating to use safe_save:

```diff
-    torch.save(ckpt_dict, f, **kwargs)
+    from modelopt.torch.utils import safe_save
+    safe_save(ckpt_dict, f, **kwargs)
```

Or add safe_save to the import at line 37.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/opt/conversion.py` at line 513, The save() function currently calls torch.save(ckpt_dict, f, **kwargs); change it to use safe_save(ckpt_dict, f, **kwargs) so saved checkpoints are compatible with the existing safe_load-based restore() and load_modelopt_state() behavior and weights_only=True loading; also add safe_save to the module imports (where other torch helpers are imported) so the symbol is available. Ensure you only replace torch.save with safe_save in the save() implementation and update imports to include safe_save.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/export/distribute.py`:
- Line 78: Replace the direct call to torch.save in the checkpoint export with
the project's safe_save utility to keep secure/consistent checkpoint handling:
change the call that currently saves {"config": config_json, "weight": weights}
to use safe_save(self.state_path, {"config": config_json, "weight": weights})
and add safe_save to the imports alongside the existing imports so the function
is available; ensure you reference the same state_path and payload used by
torch.save and keep semantics identical to the following safe_save usage
elsewhere (matching safe_load usage).
In `@modelopt/torch/opt/conversion.py`:
- Line 513: The save() function currently calls torch.save(ckpt_dict, f,
**kwargs); change it to use safe_save(ckpt_dict, f, **kwargs) so saved
checkpoints are compatible with the existing safe_load-based restore() and
load_modelopt_state() behavior and weights_only=True loading; also add safe_save
to the module imports (where other torch helpers are imported) so the symbol is
available. Ensure you only replace torch.save with safe_save in the save()
implementation and update imports to include safe_save.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 17a54433-cf2b-4f8d-80ef-e391e158845f
📒 Files selected for processing (39)
- CHANGELOG.rst
- examples/gpt-oss/convert_oai_mxfp4_weight_only.py
- examples/llm_autodeploy/run_auto_quantize.py
- examples/llm_eval/lm_eval_hf.py
- examples/llm_eval/modeling.py
- examples/llm_ptq/example_utils.py
- examples/llm_ptq/hf_ptq.py
- examples/llm_ptq/vlm_utils.py
- examples/llm_qad/data_utils/download_dataset.py
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
- examples/speculative_decoding/scripts/ar_validate.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- examples/speculative_decoding/scripts/send_conversation_vllm.py
- examples/windows/accuracy_benchmark/mmlu_benchmark.py
- examples/windows/accuracy_benchmark/modeling.py
- examples/windows/onnx_ptq/whisper/README.md
- examples/windows/onnx_ptq/whisper/whisper_onnx_quantization.py
- examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py
- examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
- modelopt/torch/export/distribute.py
- modelopt/torch/export/model_config.py
- modelopt/torch/opt/config.py
- modelopt/torch/opt/conversion.py
- modelopt/torch/opt/hparam.py
- modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
- modelopt/torch/opt/plugins/megatron.py
- modelopt/torch/opt/plugins/peft.py
- modelopt/torch/opt/searcher.py
- modelopt/torch/prune/importance_hooks/base_hooks.py
- modelopt/torch/prune/importance_hooks/compare_module_outputs.py
- modelopt/torch/quantization/calib/calibrator.py
- modelopt/torch/quantization/qtensor/base_qtensor.py
- modelopt/torch/utils/__init__.py
- modelopt/torch/utils/serialization.py
- modelopt/torch/utils/speech_dataset_utils.py
- tests/gpu/torch/export/test_vllm_fakequant_hf_export.py
- tests/gpu/torch/quantization/test_gptq.py
- tests/unit/torch/quantization/test_autoquant.py
- tests/unit/torch/utils/test_serialization.py
✅ Files skipped from review due to trivial changes (5)
- examples/windows/onnx_ptq/whisper/README.md
- tests/gpu/torch/quantization/test_gptq.py
- CHANGELOG.rst
- tests/unit/torch/utils/test_serialization.py
- examples/llm_ptq/example_utils.py
🚧 Files skipped from review as they are similar to previous changes (24)
- examples/windows/onnx_ptq/whisper/whisper_optimum_ort_inference.py
- examples/llm_eval/lm_eval_hf.py
- examples/speculative_decoding/scripts/send_conversation_vllm.py
- examples/llm_ptq/hf_ptq.py
- modelopt/torch/utils/__init__.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- examples/speculative_decoding/scripts/ar_validate.py
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
- modelopt/torch/utils/speech_dataset_utils.py
- examples/windows/accuracy_benchmark/mmlu_benchmark.py
- tests/unit/torch/quantization/test_autoquant.py
- modelopt/torch/opt/plugins/megatron.py
- modelopt/torch/opt/plugins/peft.py
- modelopt/torch/quantization/qtensor/base_qtensor.py
- modelopt/torch/export/model_config.py
- modelopt/torch/prune/importance_hooks/base_hooks.py
- examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py
- modelopt/torch/quantization/calib/calibrator.py
- examples/windows/onnx_ptq/whisper/whisper_onnx_quantization.py
- examples/llm_qad/data_utils/download_dataset.py
- modelopt/torch/utils/serialization.py
- examples/llm_eval/modeling.py
- examples/llm_ptq/vlm_utils.py
- modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
What does this PR do?

- Enables torch.load(weights_only=True) for secure checkpoint loading, registering the custom classes checkpoints contain via torch.serialization.add_safe_globals([cls]). This also removes 1 existing pickle usage.
- Removes hardcoded trust_remote_code=True.

Testing
CICD tests ran
Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.). CONTRIBUTING.md: ✅

Additional Information
NVBug: 5999336
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests