[Feat] FakeBaseModel for offline eagle; Kimi-K2.5 fixes #1052
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds FakeBaseModel and FakeBaseConfig for offline training, replaces load_vlm_or_llm_with_kwargs with load_vlm_or_llm (explicit flags), threads use_fake_base_for_offline and trust_remote_code through launch scripts and examples, makes .pt discovery recursive with rglob, and adds integration and unit tests for offline/fake-base flows.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Launcher as launch_train.sh
    participant Example as main.py / scripts
    participant Loader as load_vlm_or_llm
    participant FakeBase as FakeBaseModel
    participant HFHub as HuggingFace Hub / Filesystem
    Launcher->>Example: start training (flags: use_fake_base_for_offline, trust_remote_code, offline-data)
    Example->>Loader: load_vlm_or_llm(model_path, use_fake_base=..., use_offline_training=..., trust_remote_code=...)
    alt use_fake_base && use_offline_training
        Loader->>FakeBase: FakeBaseModel.from_source(source, trust_remote_code)
        FakeBase->>HFHub: fetch model.safetensors.index.json (local or hub)
        FakeBase->>HFHub: download shard files
        HFHub-->>FakeBase: return weights
        FakeBase-->>Loader: return FakeBaseModel
    else regular path
        Loader->>HFHub: AutoConfig.from_pretrained(..., trust_remote_code)
        Loader->>HFHub: AutoModelForCausalLM/AutoModelForVision2Seq.from_pretrained (maybe num_hidden_layers=0)
        HFHub-->>Loader: model artifacts
        Loader-->>Example: return model
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 warning, 2 inconclusive)
✅ Passed checks (1 passed)
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

```
@@           Coverage Diff            @@
##            main    #1052     +/-  ##
=========================================
- Coverage  70.21%   70.19%   -0.02%
=========================================
  Files        228      229       +1
  Lines      25952    26023      +71
=========================================
+ Hits       18221    18268      +47
- Misses      7731     7755      +24
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Nitpick comments (2)
modelopt/torch/speculative/utils.py (1)
487-508: LGTM: `trust_remote_code` properly parameterized in `load_vlm_or_llm`.

The function correctly exposes `trust_remote_code` as a caller-configurable parameter defaulting to `False`, complying with SECURITY.md guidelines. Minor: the docstring is missing the `use_fake_base` parameter description.

📝 Proposed docstring fix

```diff
 Args:
     model_name_or_path: Local path or HuggingFace repo ID of the model.
+    use_fake_base: Whether to use FakeBaseModel for offline training (default True).
     use_offline_training: Whether to load a memory-efficient model for offline training.
     torch_dtype: dtype to use when loading the model.
     device_map: Device map passed to ``from_pretrained``.
     trust_remote_code: Whether to trust remote code.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/utils.py` around lines 487-508: The docstring for load_vlm_or_llm is missing a description for the use_fake_base parameter; update the Args section to add a one-line explanation for use_fake_base (what it toggles and its default behavior), e.g., indicate that use_fake_base controls whether a FakeBaseModel is used when loading (default True) and how it interacts with use_offline_training, so readers can understand its purpose alongside model_name_or_path, use_offline_training, torch_dtype, device_map, and trust_remote_code.

modelopt/torch/speculative/plugins/modeling_fakebase.py (1)
126-129: Consider using explicit exceptions instead of `assert` for shape validation.

Assertions can be disabled with `-O` (optimized mode). For production code that validates external checkpoint data, an explicit `ValueError` or `RuntimeError` would be more robust.

♻️ Proposed refactor

```diff
-    assert lm_head_w.shape == (config.vocab_size, config.hidden_size)
-    assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size)
+    if lm_head_w.shape != (config.vocab_size, config.hidden_size):
+        raise ValueError(
+            f"lm_head shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+            f"got {lm_head_w.shape}"
+        )
+    if embed_tokens_w.shape != (config.vocab_size, config.hidden_size):
+        raise ValueError(
+            f"embed_tokens shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+            f"got {embed_tokens_w.shape}"
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 126 - 129, Replace the two assert checks with explicit runtime validation that raises a clear exception (e.g., ValueError) when shapes mismatch: check that lm_head_w.shape == (config.vocab_size, config.hidden_size) and embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and if not raise an error that includes the actual shapes and expected dimensions; keep the subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid checkpoint data fails loudly in production rather than being skipped when Python assertions are disabled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 225-227: The code sets model_id to the hardcoded string
"kimi-k2.5" when model_source is provided, causing remote-model runs to share
the same eagle_output_dir; change the assignment for model_id to derive a unique
identifier from model_source (e.g., use the repository/name suffix:
model_source.split("/")[-1] or a sanitized version of that string) so model_id
is unique per model_source and output_subdir = eagle_output_dir /
f"eagle-{model_id}-offline" writes to distinct directories; update the model_id
assignment near the model_path/model_id definitions used in the test file
(model_path, model_id, eagle_output_dir).
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 126-129: Replace the two assert checks with explicit runtime
validation that raises a clear exception (e.g., ValueError) when shapes
mismatch: check that lm_head_w.shape == (config.vocab_size, config.hidden_size)
and embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and if not
raise an error that includes the actual shapes and expected dimensions; keep the
subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and
self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid
checkpoint data fails loudly in production rather than being skipped when Python
assertions are disabled.
In `@modelopt/torch/speculative/utils.py`:
- Around line 487-508: The docstring for load_vlm_or_llm is missing a
description for the use_fake_base parameter; update the Args section to add a
one-line explanation for use_fake_base (what it toggles and its default
behavior), e.g., indicate that use_fake_base controls whether a FakeBaseModel is
used when loading (default True) and how it interacts with use_offline_training,
so readers can understand its purpose alongside model_name_or_path,
use_offline_training, torch_dtype, device_map, and trust_remote_code.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d0e27af8-82ec-41a0-a296-d037db3dbd86
📒 Files selected for processing (9)
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/ar_validate.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- modelopt/torch/speculative/plugins/modeling_fakebase.py
- modelopt/torch/speculative/plugins/transformers.py
- modelopt/torch/speculative/utils.py
- tests/examples/speculative_decoding/test_eagle.py
ChenhanYu
left a comment
This PR introduces FakeBaseModel for memory-efficient offline EAGLE training (loads only lm_head + embed_tokens), refactors load_vlm_or_llm_with_kwargs → load_vlm_or_llm with a cleaner API, consolidates model path constants, and adds several Kimi-K2.5 compatibility fixes.
Note: load_vlm_or_llm_with_kwargs → load_vlm_or_llm is a breaking API change — any downstream callers passing extra kwargs will break.
Test coverage is insufficient. The PR adds one integration test that exercises the happy path, but:

- `FakeBaseModel` has no unit tests (weight loading, path auto-detection, error cases like missing safetensors index or unrecognized weight keys)
- `load_vlm_or_llm` has no unit tests for its three distinct code paths (fake base, offline with `num_hidden_layers=0`, normal load)
- The removed `CompressedTensorsConfig` ignore patch and attention mask `repeat` have no regression tests proving they're safe to remove
- `patched_decoder_layer_fwd` (past_key_value/past_key_values fix) is untested
- 3 of 4 test cases download from remote HF repos — needs `@pytest.mark.slow` or similar for CI reliability
Please add:

- Unit tests for `FakeBaseModel` (local path happy path, missing index, wrong keys)
- Unit tests for `load_vlm_or_llm` code paths
- `@pytest.mark.slow` or network markers for the remote-model tests
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 21-27: The module currently hard-imports optional packages
(transformers, huggingface_hub, safetensors) at top-level; change to guarded
imports following the codebase pattern: wrap imports in try/except and set
boolean flags (e.g., HAS_TRANSFORMERS, HAS_HF_HUB, HAS_SAFETENSORS) or move
those imports into the functions/methods that use them (e.g., where
hf_hub_download, EntryNotFoundError, safetensors_load_file, PretrainedConfig,
PreTrainedModel are referenced or where FakeBaseModel is constructed); ensure
any code that requires these libs checks the flags and raises a clear
ImportError if used without the [hf] extras.
- Line 115: The _load_weights() routine currently requires both 'lm_head.weight'
and 'embed_tokens.weight' even when config.tie_word_embeddings is True; update
_load_weights() to check self.config.tie_word_embeddings and, if True, call
_find_weight_key() once for the existing key (prefer 'embed_tokens.weight' but
accept either) and then assign that same tensor to both lm_head and embed_tokens
rather than requiring both keys; keep the existing behavior (separately finding
both keys) when tie_word_embeddings is False and continue to use
_find_weight_key() for key discovery and error reporting.
- Around line 109-115: The config construction and module creation currently
read getattr(base_cfg, "dtype", ...) and omit passing dtype to modules; change
to read getattr(base_cfg, "torch_dtype", None) (falling back to torch.bfloat16
only if None) when creating FakeBaseConfig and then ensure that created modules
(Embedding, Linear, and any parameter tensors) are instantiated with that dtype
so weights are allocated in the checkpoint dtype; update references around
FakeBaseConfig construction and places that instantiate nn.Embedding / nn.Linear
/ torch.nn.Parameter so they pass dtype=self.model.dtype (or the local
torch_dtype) to preserve fp16/bf16 and keep self.model.dtype accurate.
In `@modelopt/torch/speculative/utils.py`:
- Around line 514-516: The code currently reads attributes like
num_hidden_layers and layer_types directly from model_config (created via
transformers.AutoConfig.from_pretrained), but composite VLM configs nest these
under text_config/llm_config; apply the same unwrapping used in FakeBaseModel by
iterating _VLM_CONFIG_ATTRS (or using the same helper) to drill into the actual
LM config before checking hasattr(model_config, "layer_types") and before
assigning num_orig_hidden_layers, so that you first replace model_config with
the unwrapped sub-config (if present) and then proceed to read num_hidden_layers
and layer_types.
- Around line 450-452: In patched_decoder_layer_fwd, avoid clobbering an
existing legacy kwarg by only mapping the new name when present: if
"past_key_values" exists in kwargs, set kwargs["past_key_value"] =
kwargs.pop("past_key_values"); otherwise leave kwargs["past_key_value"]
untouched (do not set it to None). Update the logic in patched_decoder_layer_fwd
(which calls original_decoder_layer_forward) to perform this conditional
translation so callers that pass the old "past_key_value" continue to work.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d3fe6aa2-e68e-478f-9a0b-7ecfa5ff4b6c
📒 Files selected for processing (4)
- examples/speculative_decoding/launch_train.sh
- modelopt/torch/speculative/plugins/modeling_fakebase.py
- modelopt/torch/speculative/utils.py
- tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/speculative_decoding/launch_train.sh
- tests/examples/speculative_decoding/test_eagle.py
yeyu-nvidia
left a comment
Review
Issues
1. Security: rglob("*.pt") widens attack surface (eagle_utils.py:157)
Switching from glob to rglob means recursively discovering .pt files from arbitrary subdirectories. Since .pt uses pickle, a malicious file in a nested dir could execute arbitrary code. Per SECURITY.md best practices, consider using weights_only=True in the corresponding torch.load call, or at least documenting the trust assumption.
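The safer pattern can be sketched as follows, assuming the `.pt` files hold plain tensor payloads (the function names here are illustrative, not the repo's API):

```python
from pathlib import Path

import torch


def discover_pt_files(data_dir: str) -> list:
    # rglob recurses into subdirectories, matching the PR's behavior change.
    return sorted(Path(data_dir).rglob("*.pt"))


def load_trusted_tensors(path):
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, so a malicious nested .pt cannot execute arbitrary code.
    return torch.load(path, weights_only=True)
```

Files that legitimately contain non-tensor objects would fail to load under `weights_only=True`, so the trust assumption should be documented if the flag cannot be enabled.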
2. Shared path constants have new entries that change matching behavior (modeling_fakebase.py → transformers.py)
The new shared lists add entries not present in main:
- `_BASE_MODEL_PATHS` adds `"language_model.model"` as the first entry
- `_EMBED_TOKENS_PATHS` adds `"embed_tokens"` and `"language_model.model.embed_tokens"` at positions 0 and 1
Since _find_base_model_parts matches the first hit, the new ordering could cause incorrect matches for existing models. For example, a model that has both a top-level embed_tokens (unrelated) and the real one at model.embed_tokens would now match the wrong one. The original ordering in main was carefully chosen — verify the new order doesn't break existing models (Llama, Gemma, etc.).
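The ordering hazard can be demonstrated with a toy first-hit matcher (an illustration only; the real lists and helper live in modeling_fakebase.py / transformers.py):

```python
# Candidate module paths are tried in order; the first one present wins.
EMBED_TOKENS_PATHS = [
    "embed_tokens",                       # new entry at position 0
    "language_model.model.embed_tokens",  # new entry at position 1
    "model.embed_tokens",                 # pre-existing entry
]


def find_first_match(available_paths, candidates):
    for path in candidates:
        if path in available_paths:
            return path
    return None


# A model exposing both an unrelated top-level "embed_tokens" and the real
# weights at "model.embed_tokens" now resolves to the top-level attribute.
hit = find_first_match({"embed_tokens", "model.embed_tokens"}, EMBED_TOKENS_PATHS)
print(hit)
```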
3. Missing num_orig_hidden_layers on FakeBaseModel path (utils.py:530-533)
When use_offline_training=True and use_fake_base=True, the function returns FakeBaseModel early without setting model.config.num_orig_hidden_layers. The non-fake offline path (line 562) does set this. The EAGLE training loop in main.py previously relied on this attribute. FakeBaseConfig does set num_hidden_layers from the original config, but num_orig_hidden_layers is never set. If downstream code reads num_orig_hidden_layers, this will AttributeError.
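One low-cost fix is to mirror the non-fake path before returning. A sketch, where `attach_orig_depth` is a hypothetical helper rather than the repo's API:

```python
from types import SimpleNamespace


def attach_orig_depth(fake_config, base_config):
    # Record the base model's true depth so downstream EAGLE code that
    # reads num_orig_hidden_layers keeps working (illustrative helper).
    fake_config.num_orig_hidden_layers = base_config.num_hidden_layers
    return fake_config


# Stand-in configs for demonstration; the real objects are HF configs.
cfg = attach_orig_depth(SimpleNamespace(), SimpleNamespace(num_hidden_layers=61))
print(cfg.num_orig_hidden_layers)
```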
4. load_vlm_or_llm_with_kwargs removed without deprecation (utils.py)
The old function is deleted outright. If there are any external consumers (scripts, notebooks, other repos), this is a breaking change. A one-line deprecation wrapper forwarding to load_vlm_or_llm would be low-cost and safer.
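Such a shim could be as small as the following sketch (the stub loader stands in for the real `load_vlm_or_llm`):

```python
import warnings


def load_vlm_or_llm(model_name_or_path, **kwargs):
    # Stand-in for the real loader, used here only to show the forwarding.
    return ("loaded", model_name_or_path, kwargs)


def load_vlm_or_llm_with_kwargs(model_name_or_path, **kwargs):
    """Deprecated: forwards to load_vlm_or_llm so old callers keep working."""
    warnings.warn(
        "load_vlm_or_llm_with_kwargs is deprecated; use load_vlm_or_llm",
        DeprecationWarning,
        stacklevel=2,
    )
    return load_vlm_or_llm(model_name_or_path, **kwargs)
```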
5. Removed CompressedTensors quantization patch (transformers.py:577-580)
The CompressedTensorsConfig ignore patch for the EAGLE module was removed without explanation. On main, this prevents the drafter from getting quantized when loading Kimi-K2-Thinking with compressed tensors. If that model is still supported with compressed tensors checkpoints, removing this could cause the EAGLE drafter to be quantized unexpectedly.
6. Removed attention mask .repeat() for Kimi-K2 (transformers.py:730-731)
On main, tensor_mask.repeat(batch_size, 1, 1, 1) is present with the comment "repeat mask for kimi-k2 compatibility". The PR removes it. If Kimi-K2-Thinking still needs this, removing it will break that model's attention. The PR title says "Kimi-K2.5 fixes" — was this intentionally removed because K2.5 doesn't need it, or is K2-Thinking no longer supported?
7. FSDP default behavior change is silent (launch_train.sh:206)
On main, FSDP auto-enables for multi-GPU (TOTAL_GPU > 1). The PR changes this to require --fsdp=True explicitly. Existing users doing multi-GPU training without --fsdp will silently lose FSDP, potentially causing OOM or different training behavior. This should at minimum be documented in the PR description.
8. FakeBaseModel.__init__ signature is non-standard (modeling_fakebase.py:290)
The constructor takes source instead of the standard config parameter expected by PreTrainedModel. This means from_pretrained(), save_pretrained(), and other HF utilities won't work. Fine if intentional, but worth a docstring note that this class is construction-only and not compatible with the standard HF save/load API.
9. Tests download large remote models (test_eagle.py:238-241)
The parametrized test downloads configs/indices from moonshotai/Kimi-K2.5, moonshotai/Kimi-K2-Thinking, and MiniMaxAI/MiniMax-M2.5. These will be slow and require network access. They should be gated (e.g., @pytest.mark.slow or an env var check) to avoid breaking fast CI runs.
Minor/Style
- modeling_fakebase.py: License header says "Copyright (c) 2024" — should be 2025 for new files.
- test_eagle.py:641: Missing spaces after commas in test IDs: `"kimi-k2.5","kimi-k2-thinking","minimax-m2.5"`.
- `FakeBaseConfig.dtype` stores a `torch.dtype` object — this won't round-trip through HF's JSON config serialization cleanly.
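The dtype round-trip issue is easy to reproduce; one common workaround, sketched below and not necessarily the repo's fix, is to serialize the dtype as a string and rebuild it on load:

```python
import json

import torch

# json.dumps cannot handle a torch.dtype value directly.
try:
    json.dumps({"dtype": torch.bfloat16})
except TypeError:
    print("torch.dtype is not JSON-serializable")

# Store the string form instead, and rebuild the dtype when loading.
serialized = json.dumps({"dtype": str(torch.bfloat16)})
name = json.loads(serialized)["dtype"].split(".", 1)[1]  # drop "torch." prefix
restored = getattr(torch, name)
assert restored is torch.bfloat16
```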
What looks good
- `FakeBaseModel` design is clean — loads only what's needed, auto-detects weight paths
- `load_vlm_or_llm` with explicit parameters is a much better API than `**kwargs`
- The `self._base_model_lm_head` fix in offline mode is correct and important
- Kimi decoder `past_key_value`/`past_key_values` patch is a clean compatibility fix
- Single-node GPU detection via `torch.cuda.device_count()` respecting `CUDA_VISIBLE_DEVICES` is a good improvement
- Removing hardcoded `trust_remote_code=True` in favor of a CLI flag is a security improvement
- Unit tests are well-structured with proper monkeypatching
Comment 8 addressed in new commits. I think the other comments do not make sense. Please let me know if you think they do. Thanks
Actionable comments posted: 1
♻️ Duplicate comments (4)
modelopt/torch/speculative/utils.py (1)
450-454: ⚠️ Potential issue | 🟠 Major

Use conditional translation to avoid clobbering an existing `past_key_value` kwarg.

The current implementation unconditionally sets `past_key_value` to `None` when `past_key_values` is absent. This breaks backward compatibility for callers that pass the legacy `past_key_value` parameter directly. Only translate when the newer key is present.

🐛 Proposed fix
```diff
 def patched_decoder_layer_fwd(self, *args, **kwargs):
-    kwargs["past_key_value"] = kwargs.pop("past_key_values", None)
+    if "past_key_values" in kwargs:
+        kwargs.setdefault("past_key_value", kwargs.pop("past_key_values"))
     return original_decoder_layer_forward(self, *args, **kwargs)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/utils.py` around lines 450-454, patched_decoder_layer_fwd currently unconditionally sets kwargs["past_key_value"] = kwargs.pop("past_key_values", None), which overwrites callers that pass the legacy past_key_value; change it to only translate when the new key exists and avoid clobbering an existing legacy value: inside patched_decoder_layer_fwd check if "past_key_values" in kwargs, and if so, if "past_key_value" not in kwargs then set kwargs["past_key_value"] = kwargs.pop("past_key_values") else just kwargs.pop("past_key_values") to remove the duplicate, then call original_decoder_layer_forward(self, *args, **kwargs); keep assignment to kimi_k2_module.DeepseekV3DecoderLayer.forward as-is.

modelopt/torch/speculative/plugins/modeling_fakebase.py (3)
197-211: ⚠️ Potential issue | 🟠 Major

Handle tied-word-embedding checkpoints where only one weight key exists.

When `config.tie_word_embeddings=True`, HuggingFace safetensors serialization deduplicates tied weights, saving only one copy (typically `embed_tokens.weight`). The current implementation requires both keys, causing valid tied-embedding checkpoints to fail with `RuntimeError`. Check `self.config.tie_word_embeddings` and reuse the surviving weight for both modules.

🐛 Proposed fix
```diff
 def _load_weights(self, source: str):
     """Load lm_head and embed_tokens weights from a local directory or HuggingFace Hub repo."""
     weight_map = self._load_index(source)
-    lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head")
-    embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens")
-
-    lm_head_path, embed_tokens_path = self._resolve_shard_paths(
-        source, [weight_map[lm_head_key], weight_map[embed_tokens_key]]
-    )
-
-    lm_head_state = safetensors_load_file(lm_head_path, device="cpu")
-    embed_tokens_state = safetensors_load_file(embed_tokens_path, device="cpu")
-
-    return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key]
+    if self.config.tie_word_embeddings:
+        # Tied embeddings: only one weight is serialized, try embed_tokens first
+        try:
+            key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens")
+        except RuntimeError:
+            key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head")
+        shard_path = self._resolve_shard_paths(source, [weight_map[key]])[0]
+        state = safetensors_load_file(shard_path, device="cpu")
+        weight = state[key]
+        return weight, weight  # Same tensor for both
+    else:
+        lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head")
+        embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens")
+
+        lm_head_path, embed_tokens_path = self._resolve_shard_paths(
+            source, [weight_map[lm_head_key], weight_map[embed_tokens_key]]
+        )
+
+        lm_head_state = safetensors_load_file(lm_head_path, device="cpu")
+        embed_tokens_state = safetensors_load_file(embed_tokens_path, device="cpu")
+
+        return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 197 - 211, The _load_weights flow currently assumes both lm_head and embed_tokens keys exist; when self.config.tie_word_embeddings is True safetensors may store only one deduplicated key (typically embed_tokens.weight), causing a RuntimeError. Update _load_weights to detect missing second key by inspecting weight_map and self.config.tie_word_embeddings; if one key is absent, resolve shards for the surviving key and reuse its tensor for both lm_head and embed_tokens outputs (i.e., set lm_head_state_val = embed_tokens_state_val or vice versa) before returning. Use the existing helpers (_find_weight_key, _resolve_shard_paths, safetensors_load_file) and the variables lm_head_key, embed_tokens_key, weight_map to locate and load the surviving weight and return the same tensor for both modules when tied embeddings are enabled.
141-141:⚠️ Potential issue | 🟠 MajorUse
torch_dtypeattribute instead ofdtypeto read checkpoint dtype.HuggingFace
PretrainedConfigexposes the checkpoint dtype viatorch_dtype, notdtype. The current code readsdtypewhich doesn't exist, so it always defaults totorch.bfloat16regardless of the actual checkpoint dtype. This undermines dtype consistency.🐛 Proposed fix
```diff
 config = FakeBaseConfig(
     num_hidden_layers=getattr(base_cfg, "num_hidden_layers", None),
     hidden_size=getattr(base_cfg, "hidden_size", None),
     vocab_size=getattr(base_cfg, "vocab_size", None),
     max_position_embeddings=getattr(base_cfg, "max_position_embeddings", None),
-    dtype=getattr(base_cfg, "dtype", torch.bfloat16),
+    dtype=getattr(base_cfg, "torch_dtype", None) or torch.bfloat16,
     tie_word_embeddings=getattr(base_cfg, "tie_word_embeddings", False),
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` at line 141, The code reads checkpoint dtype using getattr(base_cfg, "dtype", torch.bfloat16) which is wrong; change it to read getattr(base_cfg, "torch_dtype", torch.bfloat16) (optionally fallback to "dtype" if you want compatibility) so the model uses the checkpoint's actual dtype; update the assignment where dtype is set in modeling_fakebase.py (the line currently using getattr(base_cfg, "dtype", ...)) to use "torch_dtype" (or a two-step check like base_cfg.torch_dtype or base_cfg.dtype) to preserve dtype consistency.
21-33: 🛠️ Refactor suggestion | 🟠 Major

Gate optional dependency imports behind try/except or move them to function scope.

Per coding guidelines, optional dependencies (`transformers`, `huggingface_hub`, `safetensors`) must be gated. These hard imports at module level will fail if the `[hf]` extra is not installed, even if `FakeBaseModel` is never used. The codebase pattern (seen in `gradnas.py`, etc.) uses try/except with flags or moves imports into functions.

♻️ Proposed fix - move imports to function/method scope
```diff
-import torch
-import torch.nn as nn
-import transformers
-from huggingface_hub import hf_hub_download
-from huggingface_hub.errors import EntryNotFoundError
-from safetensors.torch import load_file as safetensors_load_file
-from transformers import (
-    AutoConfig,
-    AutoModel,
-    AutoModelForCausalLM,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+import torch
+import torch.nn as nn
+
+try:
+    import transformers
+    from huggingface_hub import hf_hub_download
+    from huggingface_hub.errors import EntryNotFoundError
+    from safetensors.torch import load_file as safetensors_load_file
+    from transformers import (
+        AutoConfig,
+        AutoModel,
+        AutoModelForCausalLM,
+        PretrainedConfig,
+        PreTrainedModel,
+    )
+    _HAS_HF_DEPS = True
+except ImportError:
+    _HAS_HF_DEPS = False
```

Then check `_HAS_HF_DEPS` in `from_source` and raise a clear error if dependencies are missing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 21 - 33, Top-level imports for optional HF deps (transformers, huggingface_hub, safetensors) must be gated; move those imports into the function/method scope or wrap them in a try/except that sets a _HAS_HF_DEPS flag. Specifically, remove or relocate the module-level imports in modeling_fakebase.py and either (a) import AutoConfig/AutoModel/AutoModelForCausalLM/PretrainedConfig/PreTrainedModel, hf_hub_download, EntryNotFoundError and safetensors_load_file inside FakeBaseModel.from_source (or any helper it calls) or (b) wrap the top-level imports in try/except to set _HAS_HF_DEPS=False and then at the start of FakeBaseModel.from_source check _HAS_HF_DEPS and raise a clear error if missing; reference FakeBaseModel and its from_source method when applying the change.
🧹 Nitpick comments (1)
tests/examples/speculative_decoding/test_eagle.py (1)
231-233: Consider using attribute-based VLM detection instead of a hardcoded model name.

The current check only handles `moonshotai/Kimi-K2.5`, but other parametrized models (`Kimi-K2-Thinking`, `MiniMax-M2.5`) may also be VLMs requiring config unwrapping. This mirrors the pattern in `FakeBaseModel.from_source`, which uses `_VLM_CONFIG_ATTRS`.

♻️ Proposed fix
```diff
 cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-if model_source == "moonshotai/Kimi-K2.5":
-    # vlm, get text config
-    cfg = cfg.text_config
+# For VLMs, unwrap to the language model config
+for attr in ["text_config", "llm_config"]:
+    if getattr(cfg, attr, None) is not None:
+        cfg = getattr(cfg, attr)
+        break
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` around lines 231 - 233, The code currently checks model_source=="moonshotai/Kimi-K2.5" to decide to unwrap a VLM config and set cfg = cfg.text_config; instead, detect VLMs by attribute presence like FakeBaseModel.from_source does using _VLM_CONFIG_ATTRS: check for any attribute in _VLM_CONFIG_ATTRS on the model source or cfg object and if present set cfg = cfg.text_config (or the corresponding attr) so other parametrized VLMs (e.g., Kimi-K2-Thinking, MiniMax-M2.5) are handled without hardcoding model names; update the conditional around model_source and cfg to use that attribute-based test and reuse the same attribute list used by FakeBaseModel.from_source.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/utils.py`:
- Around line 514-538: The offline-training branch reads num_hidden_layers and
layer_types directly from model_config which fails for composite VLM configs;
before inspecting those attributes (in the block that sets extra and after
loading model_config) unwrap VLM configs the same way FakeBaseModel.from_source
does (use the same _VLM_CONFIG_ATTRS/text_config or llm_config lookup) so you
pull num_hidden_layers and layer_types from the nested text/llm config when
present; then set extra["num_hidden_layers"], extra["layer_types"] and later
model.config.num_orig_hidden_layers from the unwrapped values rather than the
top-level model_config fields.
---
Duplicate comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 197-211: The _load_weights flow currently assumes both lm_head and
embed_tokens keys exist; when self.config.tie_word_embeddings is True
safetensors may store only one deduplicated key (typically embed_tokens.weight),
causing a RuntimeError. Update _load_weights to detect missing second key by
inspecting weight_map and self.config.tie_word_embeddings; if one key is absent,
resolve shards for the surviving key and reuse its tensor for both lm_head and
embed_tokens outputs (i.e., set lm_head_state_val = embed_tokens_state_val or
vice versa) before returning. Use the existing helpers (_find_weight_key,
_resolve_shard_paths, safetensors_load_file) and the variables lm_head_key,
embed_tokens_key, weight_map to locate and load the surviving weight and return
the same tensor for both modules when tied embeddings are enabled.
- Line 141: The code reads checkpoint dtype using getattr(base_cfg, "dtype",
torch.bfloat16) which is wrong; change it to read getattr(base_cfg,
"torch_dtype", torch.bfloat16) (optionally fallback to "dtype" if you want
compatibility) so the model uses the checkpoint's actual dtype; update the
assignment where dtype is set in modeling_fakebase.py (the line currently using
getattr(base_cfg, "dtype", ...)) to use "torch_dtype" (or a two-step check like
base_cfg.torch_dtype or base_cfg.dtype) to preserve dtype consistency.
- Around line 21-33: Top-level imports for optional HF deps (transformers,
huggingface_hub, safetensors) must be gated; move those imports into the
function/method scope or wrap them in a try/except that sets a _HAS_HF_DEPS
flag. Specifically, remove or relocate the module-level imports in
modeling_fakebase.py and either (a) import
AutoConfig/AutoModel/AutoModelForCausalLM/PretrainedConfig/PreTrainedModel,
hf_hub_download, EntryNotFoundError and safetensors_load_file inside
FakeBaseModel.from_source (or any helper it calls) or (b) wrap the top-level
imports in try/except to set _HAS_HF_DEPS=False and then at the start of
FakeBaseModel.from_source check _HAS_HF_DEPS and raise a clear error if missing;
reference FakeBaseModel and its from_source method when applying the change.
In `@modelopt/torch/speculative/utils.py`:
- Around line 450-454: patched_decoder_layer_fwd currently unconditionally sets
kwargs["past_key_value"] = kwargs.pop("past_key_values", None), which overwrites
callers that pass the legacy past_key_value; change it to only translate when
the new key exists and avoid clobbering an existing legacy value: inside
patched_decoder_layer_fwd check if "past_key_values" in kwargs, and if so, if
"past_key_value" not in kwargs then set kwargs["past_key_value"] =
kwargs.pop("past_key_values") else just kwargs.pop("past_key_values") to remove
the duplicate, then call original_decoder_layer_forward(self, *args, **kwargs);
keep assignment to kimi_k2_module.DeepseekV3DecoderLayer.forward as-is.
---
Nitpick comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 231-233: The code currently checks
model_source=="moonshotai/Kimi-K2.5" to decide to unwrap a VLM config and set
cfg = cfg.text_config; instead, detect VLMs by attribute presence like
FakeBaseModel.from_source does using _VLM_CONFIG_ATTRS: check for any attribute
in _VLM_CONFIG_ATTRS on the model source or cfg object and if present set cfg =
cfg.text_config (or the corresponding attr) so other parametrized VLMs (e.g.,
Kimi-K2-Thinking, MiniMax-M2.5) are handled without hardcoding model names;
update the conditional around model_source and cfg to use that attribute-based
test and reuse the same attribute list used by FakeBaseModel.from_source.
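The attribute-based VLM detection described above can be sketched like this. `_VLM_CONFIG_ATTRS` is the list the review says to reuse from `FakeBaseModel.from_source`; the entries below are illustrative stand-ins, not the real values.

```python
# Sketch of attribute-based VLM config unwrapping; the attribute names in
# _VLM_CONFIG_ATTRS are illustrative stand-ins for the real list.
from types import SimpleNamespace

_VLM_CONFIG_ATTRS = ("text_config", "language_config", "llm_config")


def unwrap_text_config(cfg):
    """Return the inner language-model config of a VLM config, else cfg itself."""
    for attr in _VLM_CONFIG_ATTRS:
        inner = getattr(cfg, attr, None)
        if inner is not None:
            return inner
    return cfg


# A VLM-style config nests its language model; a plain LLM config does not.
vlm_cfg = SimpleNamespace(text_config=SimpleNamespace(hidden_size=4096))
llm_cfg = SimpleNamespace(hidden_size=2048)
```

This handles any parametrized VLM with a recognized wrapper attribute, with no model-name hardcoding.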
📒 Files selected for processing (4)
- modelopt/torch/speculative/plugins/modeling_fakebase.py
- modelopt/torch/speculative/utils.py
- tests/examples/speculative_decoding/test_eagle.py
- tests/unit/torch/speculative/plugins/test_fakebase.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/torch/speculative/plugins/test_fakebase.py
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
c118d94 to cb6802c
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/launch_train.sh (1)
220-226: ⚠️ Potential issue | 🟡 Minor — Document FSDP behavior change.
FSDP is now disabled by default and requires an explicit `--fsdp=True` even for multi-GPU setups. Previously, FSDP was auto-enabled when `TOTAL_GPU > 1`. This is a breaking change for existing multi-GPU training workflows. Consider documenting this change in the PR description or a changelog entry.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/launch_train.sh` around lines 220 - 226, Update the PR description or project changelog to note that FSDP is now disabled by default and must be set with --fsdp=True even on multi-GPU systems (previously auto-enabled when TOTAL_GPU > 1); specifically mention the behavior around the launch script variables TOTAL_GPU, FSDP, and FSDP_ARGS (including the use of fsdp_config.json) and the potential breaking impact on multi-GPU workflows, and optionally add a short warning in launch_train.sh when TOTAL_GPU>1 and FSDP is not True so users see the change at runtime.
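The optional runtime warning mentioned above could look like the following sketch. `TOTAL_GPU` and `FSDP` are the launch-script variables named in the review; the defaults here are only illustrative.

```shell
# Sketch of the optional runtime warning; TOTAL_GPU and FSDP are the launch
# script variables named above, and the defaults here are only illustrative.
TOTAL_GPU=${TOTAL_GPU:-2}
FSDP=${FSDP:-False}

if [ "$TOTAL_GPU" -gt 1 ] && [ "$FSDP" != "True" ]; then
  echo "WARNING: FSDP is no longer auto-enabled when TOTAL_GPU > 1;" \
       "pass --fsdp=True to restore the previous multi-GPU behavior." >&2
fi
```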
♻️ Duplicate comments (3)
modelopt/torch/speculative/utils.py (1)
450-452: ⚠️ Potential issue | 🟡 Minor — Preserve existing `past_key_value` kwarg.
The current implementation unconditionally sets `past_key_value` from `past_key_values`, potentially overwriting an existing `past_key_value` kwarg with `None` when `past_key_values` is absent. Only translate when `past_key_values` is actually present.
🔧 Proposed fix
```diff
 def patched_decoder_layer_fwd(self, *args, **kwargs):
-    kwargs["past_key_value"] = kwargs.pop("past_key_values", None)
+    if "past_key_values" in kwargs:
+        kwargs.setdefault("past_key_value", kwargs.pop("past_key_values"))
     return original_decoder_layer_forward(self, *args, **kwargs)
```
🤖 Prompt for AI Agents
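The proposed fix can be exercised in isolation; in this sketch, `original_forward` is a stand-in for the real decoder-layer forward that just echoes the legacy kwarg.

```python
# Standalone sketch of the kwarg translation in the proposed fix;
# original_forward stands in for the real decoder-layer forward.
def original_forward(self, *args, **kwargs):
    return kwargs.get("past_key_value")


def patched_decoder_layer_fwd(self, *args, **kwargs):
    # Translate only when the new-style key is present, and never clobber an
    # explicitly passed legacy past_key_value.
    if "past_key_values" in kwargs:
        kwargs.setdefault("past_key_value", kwargs.pop("past_key_values"))
    return original_forward(self, *args, **kwargs)
```

Note that `setdefault` keeps an explicitly passed legacy value while still popping the new-style key, so the duplicate never reaches the wrapped forward.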
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/utils.py` around lines 450 - 452, patched_decoder_layer_fwd currently always sets kwargs["past_key_value"] = kwargs.pop("past_key_values", None) which can overwrite an existing past_key_value with None; change it to only transfer past_key_values when that key exists (e.g., if "past_key_values" in kwargs: kwargs["past_key_value"] = kwargs.pop("past_key_values")) before calling original_decoder_layer_forward(self, *args, **kwargs); preserve any existing past_key_value when past_key_values is absent.
modelopt/torch/speculative/plugins/modeling_fakebase.py (2)
141-141: ⚠️ Potential issue | 🟡 Minor — Use `torch_dtype` instead of `dtype` from config.
`AutoConfig.from_pretrained()` exposes the checkpoint dtype via `torch_dtype`, not `dtype`. This will always fall back to `torch.bfloat16` regardless of the actual checkpoint dtype.

```diff
- dtype=getattr(base_cfg, "dtype", torch.bfloat16),
+ dtype=getattr(base_cfg, "torch_dtype", None) or torch.bfloat16,
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` at line 141, The config uses getattr(base_cfg, "dtype", torch.bfloat16) which is incorrect because AutoConfig exposes the checkpoint dtype as torch_dtype; update the call in modeling_fakebase.py (where base_cfg is used to set dtype) to use getattr(base_cfg, "torch_dtype", torch.bfloat16) so the actual checkpoint dtype is respected while keeping torch.bfloat16 as the fallback.
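The fallback order can be illustrated on its own; string stand-ins replace actual `torch` dtypes to keep the sketch dependency-free, and `SimpleNamespace` mimics an `AutoConfig` object.

```python
# Sketch of the dtype resolution order; "bfloat16" stands in for
# torch.bfloat16, and SimpleNamespace mimics an AutoConfig object.
from types import SimpleNamespace

DEFAULT_DTYPE = "bfloat16"  # stand-in for torch.bfloat16


def resolve_dtype(base_cfg):
    # Prefer the checkpoint's recorded torch_dtype; fall back to the default.
    return getattr(base_cfg, "torch_dtype", None) or DEFAULT_DTYPE


fp16_cfg = SimpleNamespace(torch_dtype="float16")  # checkpoint declares fp16
bare_cfg = SimpleNamespace()                       # no dtype recorded
```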
197-211: ⚠️ Potential issue | 🟠 Major — Handle tied-word-embedding checkpoints.
When `config.tie_word_embeddings=True`, HuggingFace safetensors typically saves only one copy of the tied weights (usually `embed_tokens.weight`). The current `_load_weights()` unconditionally requires both keys, which will cause a `RuntimeError` for valid tied-embedding checkpoints. Check `self.config.tie_word_embeddings` and reuse the surviving weight key for both tensors when `True`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 197 - 211, The _load_weights method currently assumes both lm_head and embed_tokens keys exist; update it to handle tied embeddings by checking self.config.tie_word_embeddings and, if True, detect when one key is missing and reuse the present key for both tensors (e.g., if only embed_tokens_key exists, set lm_head_key = embed_tokens_key, or vice versa) before calling _resolve_shard_paths, safetensors_load_file, and returning the pair; adjust logic around _find_weight_key, lm_head_key, embed_tokens_key, _resolve_shard_paths, and the final return so the same loaded tensor is returned for both lm_head and embed_tokens when tie_word_embeddings is True.
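The key-resolution step of that fix can be sketched as a pure function over the index's `weight_map`. The key names and helper below are illustrative, not the plugin's real code; the real `_load_weights` would additionally resolve shard paths and load the tensors.

```python
# Sketch of tied-embedding key resolution over a safetensors index weight_map;
# the key names and helper are illustrative, not the plugin's real code.
def resolve_tied_keys(weight_map, tie_word_embeddings,
                      lm_head_key="lm_head.weight",
                      embed_tokens_key="model.embed_tokens.weight"):
    """Return the (lm_head, embed_tokens) keys to load, reusing the surviving
    key when a tied-embedding checkpoint deduplicated one of them."""
    has_lm_head = lm_head_key in weight_map
    has_embed = embed_tokens_key in weight_map
    if has_lm_head and has_embed:
        return lm_head_key, embed_tokens_key
    if tie_word_embeddings and (has_lm_head or has_embed):
        survivor = embed_tokens_key if has_embed else lm_head_key
        return survivor, survivor
    raise RuntimeError("lm_head/embed_tokens weights not found in weight_map")
```

With untied embeddings a missing key still raises, so genuinely broken checkpoints keep failing loudly.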
🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/modeling_fakebase.py (1)
1-1: Update copyright year to 2025. The copyright year should be 2025 for new files added in 2025/2026.

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` at line 1, Update the SPDX header year from 2024 to 2025 at the top of the file: change the copyright notice in the SPDX comment (the first line in modeling_fakebase.py) so it reads 2025 instead of 2024.
tests/examples/speculative_decoding/test_eagle.py (2)
216-216: Minor: Add consistent spacing in test IDs.

```diff
- ids=["tinyllama", "kimi-k2.5","kimi-k2-thinking","minimax-m2.5"],
+ ids=["tinyllama", "kimi-k2.5", "kimi-k2-thinking", "minimax-m2.5"],
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` at line 216, In the tests/examples/speculative_decoding/test_eagle.py file update the ids list passed in the test (the ids=["tinyllama", "kimi-k2.5","kimi-k2-thinking","minimax-m2.5"] entry) to use consistent spacing after commas so each item is separated by ", " (e.g. "tinyllama", "kimi-k2.5", "kimi-k2-thinking", "minimax-m2.5") for consistent style.
274-308: Consider using `pytest.mark.dependency` or fixtures for test ordering.
`test_offline_resume_training_kimi` depends on `test_offline_eagle3_training["kimi-k2.5"]` having run first to create the checkpoint at `eagle-Kimi-K2.5-offline`. While pytest typically runs tests in file order, this implicit dependency is fragile. Consider:
- Using `pytest-depends` with `@pytest.mark.depends(on=...)`, or
- Converting the initial training into a session-scoped fixture that the resume test depends on.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` around lines 274 - 308, The resume test test_offline_resume_training_kimi has an implicit dependency on test_offline_eagle3_training["kimi-k2.5"] creating the checkpoint; make this explicit by either (A) adding a dependency marker like `@pytest.mark.depends`(on="test_offline_eagle3_training[kimi-k2.5]") above test_offline_resume_training_kimi, or (B) refactoring the initial training into a session-scoped fixture (e.g., eagle_kimi_checkpoint_fixture) that returns the checkpoint_dir and then inject that fixture into test_offline_resume_training_kimi (replace direct use of eagle_output_dir/eagle-Kimi-K2.5-offline), ensuring the resume test uses the fixture so pytest guarantees order and availability of the checkpoint.
📒 Files selected for processing (10)
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/ar_validate.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- modelopt/torch/speculative/plugins/modeling_fakebase.py
- modelopt/torch/speculative/plugins/transformers.py
- modelopt/torch/speculative/utils.py
- tests/examples/speculative_decoding/test_eagle.py
- tests/unit/torch/speculative/plugins/test_fakebase.py
✅ Files skipped from review due to trivial changes (1)
- tests/unit/torch/speculative/plugins/test_fakebase.py
🚧 Files skipped from review as they are similar to previous changes (4)
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- examples/speculative_decoding/scripts/ar_validate.py
- examples/speculative_decoding/main.py
What does this PR do?
Adds `FakeBaseModel` for offline EAGLE training and several Kimi-K2.5 compatibility fixes.
- `FakeBaseModel` — lightweight model that loads only `lm_head` and `embed_tokens` from a local checkpoint, avoiding full model weight loading during offline training. Configured via `FakeBaseArguments` and integrated into `load_vlm_or_llm`.
- `_find_base_model_parts` — support Kimi-K2.5 VLM layout (`language_model.model` path)
- Fix `past_key_value`/`past_key_values` argument mismatch
- Use `rglob` for `.pt` discovery in nested offline data dirs; single-node GPU count respects `CUDA_VISIBLE_DEVICES`

Type of change: Bug fix, new feature
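The `rglob` change in the data-discovery bullet can be illustrated in isolation; the directory layout below is made up.

```python
# Sketch of recursive .pt discovery: glob misses files in nested shard dirs,
# rglob finds them; the directory layout here is illustrative.
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
(root / "shard_00").mkdir()
(root / "shard_00" / "sample_0.pt").touch()
(root / "top_level.pt").touch()

flat = sorted(p.name for p in root.glob("*.pt"))     # non-recursive
nested = sorted(p.name for p in root.rglob("*.pt"))  # recursive
```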
Testing
Tested offline EAGLE training for Kimi-K2.5 end-to-end.
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information