
[Feat] FakeBaseModel for offline eagle; Kimi-K2.5 fixes; #1052

Open
h-guo18 wants to merge 6 commits into main from haoguo/fakebasemodel

Conversation

@h-guo18 (Contributor) commented Mar 17, 2026

What does this PR do?

Adds FakeBaseModel for offline EAGLE training and several Kimi-K2.5 compatibility fixes.

  • New: FakeBaseModel — lightweight model that loads only lm_head and embed_tokens from a local checkpoint, avoiding full model weight loading during offline training. Configured via FakeBaseArguments and integrated into load_vlm_or_llm.
  • Fix: _find_base_model_parts — support Kimi-K2.5 VLM layout (language_model.model path)
  • Fix: offline mode lm_head access and CompressedTensors ignore path
  • Fix: Kimi-K2.5 decoder past_key_value/past_key_values argument mismatch
  • Fix: rglob for .pt discovery in nested offline data dirs; single-node GPU count respects CUDA_VISIBLE_DEVICES
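The shard-resolution idea behind FakeBaseModel can be sketched as follows. This is an illustrative Python sketch, not the actual modeling_fakebase.py code; the index layout mirrors HF's model.safetensors.index.json format, and the function name is hypothetical.

```python
# Given a sharded checkpoint's weight_map, find only the shards that hold
# lm_head / embed_tokens, so the full model weights never need to load.
import json

def shards_for_keys(index_json: str,
                    wanted_suffixes=("lm_head.weight", "embed_tokens.weight")):
    """Return {weight_key: shard_file} for keys ending in any wanted suffix."""
    weight_map = json.loads(index_json)["weight_map"]
    return {
        key: shard
        for key, shard in weight_map.items()
        if key.endswith(wanted_suffixes)
    }

# Minimal fake index mimicking a sharded checkpoint layout.
fake_index = json.dumps({
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }
})
print(shards_for_keys(fake_index))
```

Only the matched shard files would then be opened (locally or via hub download) to copy the two tensors into the fake model.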

Type of change: Bug fix, new feature

Testing

Tested offline EAGLE training for Kimi-K2.5 end-to-end.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ❌
  • Did you update Changelog?: ❌

Additional Information

Summary by CodeRabbit

  • New Features

    • Lightweight fake-base model support for offline speculative-decoding training
  • Improvements

    • Added CLI flags: --use_fake_base_for_offline, --trust_remote_code, and --fsdp
    • Expanded offline .pt discovery to include nested subdirectories
    • Better GPU detection with explicit single-node logging; FSDP enabled only when requested
    • Model loading and launch tooling now honor offline and trust-remote-code flags
  • Bug Fixes

    • Improved compatibility with legacy transformer / Kimi-K2 call signatures
  • Tests

    • Added tests covering fake-base loading and offline training workflows

copy-pr-bot bot commented Mar 17, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

coderabbitai bot (Contributor) commented Mar 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds FakeBaseModel and FakeBaseConfig for offline training, replaces load_vlm_or_llm_with_kwargs with load_vlm_or_llm (explicit flags), threads use_fake_base_for_offline and trust_remote_code through launch scripts and examples, makes .pt discovery recursive with rglob, and adds integration and unit tests for offline/fake-base flows. (50 words)

Changes

Cohort / File(s) Summary
Example callsites
examples/speculative_decoding/main.py, examples/speculative_decoding/scripts/ar_validate.py, examples/speculative_decoding/scripts/export_hf_checkpoint.py
Replaced load_vlm_or_llm_with_kwargs → load_vlm_or_llm; loader now returns only model. Call sites pass/propagate explicit flags (use_fake_base_for_offline, trust_remote_code) and stop unpacking (config, model).
Launch & utils (examples)
examples/speculative_decoding/launch_train.sh, examples/speculative_decoding/eagle_utils.py
Added CLI flags --use_fake_base_for_offline, --trust_remote_code, --fsdp; adjusted GPU/TOTAL_GPU detection and conditional FSDP enabling; eagle_utils.py now finds .pt files recursively with Path.rglob("*.pt").
Model loading core
modelopt/torch/speculative/utils.py
Removed load_vlm_or_llm_with_kwargs; added load_vlm_or_llm(model_name_or_path, use_fake_base, use_offline_training, torch_dtype, device_map, trust_remote_code) returning only model; added Kimi-K2 runtime monkey-patch and offline-loading adjustments (zeroed num_hidden_layers and recording original count).
Fake-base plugin
modelopt/torch/speculative/plugins/modeling_fakebase.py
New module introducing FakeBaseConfig and FakeBaseModel with from_source(...) to read HF config, locate safetensors index (local or hub), resolve shards, load embed_tokens and lm_head weights, validate shapes, and register types with HF Auto classes. Forward is unimplemented.
Transformers plugin tweaks
modelopt/torch/speculative/plugins/transformers.py
Centralized base/embed/lm-head path constants usage, removed earlier quantization patch, changed offline logits fallback to _base_model_lm_head(...), simplified attention-mask generation, and minor loop/variable changes.
Tests — integration & unit
tests/examples/speculative_decoding/test_eagle.py, tests/unit/torch/speculative/plugins/test_fakebase.py
Added generate_offline_pt_data, integration tests invoking launch_train.sh for offline training and resume cases, and unit tests for FakeBaseModel.from_source (safetensors index/shard handling) and load_vlm_or_llm offline/fake-base behaviors.

Sequence Diagram(s)

sequenceDiagram
  participant Launcher as launch_train.sh
  participant Example as main.py / scripts
  participant Loader as load_vlm_or_llm
  participant FakeBase as FakeBaseModel
  participant HFHub as HuggingFace Hub / Filesystem

  Launcher->>Example: start training (flags: use_fake_base_for_offline, trust_remote_code, offline-data)
  Example->>Loader: load_vlm_or_llm(model_path, use_fake_base=..., use_offline_training=..., trust_remote_code=...)
  alt use_fake_base && use_offline_training
    Loader->>FakeBase: FakeBaseModel.from_source(source, trust_remote_code)
    FakeBase->>HFHub: fetch model.safetensors.index.json (local or hub)
    FakeBase->>HFHub: download shard files
    HFHub-->>FakeBase: return weights
    FakeBase-->>Loader: return FakeBaseModel
  else regular path
    Loader->>HFHub: AutoConfig.from_pretrained(..., trust_remote_code)
    Loader->>HFHub: AutoModelForCausalLM/AutoModelForVision2Seq.from_pretrained (maybe num_hidden_layers=0)
    HFHub-->>Loader: model artifacts
    Loader-->>Example: return model
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 warning, 2 inconclusive)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Title check (❓ Inconclusive): the title references FakeBaseModel and Kimi-K2.5 fixes, which are core changes in the PR, but the phrasing is vague and lacks specific context about the primary feature or fix. Resolution: consider a more explicit title, such as 'Add FakeBaseModel for offline EAGLE training' or 'Support offline EAGLE with FakeBaseModel and fix Kimi-K2.5 compatibility'.
  • Security Anti-Patterns (❓ Inconclusive): cannot access the file system to view eagle_utils.py lines 70-90 directly; web search performed but specific file contents not available. Resolution: provide the relevant code section from eagle_utils.py or clone the repository locally to inspect torch.load usage and the weights_only parameter.
✅ Passed checks (1 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.

Comment @coderabbitai help to get the list of available commands and usage tips.

@h-guo18 changed the title from "Add FakeBaseModel for offline speculative decoding and Kimi-K2.5 fixes" to "[Feat]FakeBaseModel for offline eagle; Kimi-K2.5 fixes;" Mar 17, 2026
codecov bot commented Mar 17, 2026

Codecov Report

❌ Patch coverage is 57.89474% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.19%. Comparing base (291498b) to head (4613e80).
⚠️ Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/utils.py 57.89% 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1052      +/-   ##
==========================================
- Coverage   70.21%   70.19%   -0.02%     
==========================================
  Files         228      229       +1     
  Lines       25952    26023      +71     
==========================================
+ Hits        18221    18268      +47     
- Misses       7731     7755      +24     


@h-guo18 self-assigned this Mar 17, 2026
@h-guo18 marked this pull request as ready for review March 18, 2026 17:45
@h-guo18 requested a review from a team as a code owner March 18, 2026 17:45
@h-guo18 requested a review from yeyu-nvidia March 18, 2026 17:45
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
modelopt/torch/speculative/utils.py (1)

487-508: LGTM: trust_remote_code properly parameterized in load_vlm_or_llm.

The function correctly exposes trust_remote_code as a caller-configurable parameter defaulting to False, complying with SECURITY.md guidelines.

Minor: The docstring is missing the use_fake_base parameter description.

📝 Proposed docstring fix
     Args:
         model_name_or_path: Local path or HuggingFace repo ID of the model.
+        use_fake_base: Whether to use FakeBaseModel for offline training (default True).
         use_offline_training: Whether to load a memory-efficient model for offline training.
         torch_dtype: dtype to use when loading the model.
         device_map: Device map passed to ``from_pretrained``.
         trust_remote_code: Whether to trust remote code.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/utils.py` around lines 487 - 508, The docstring
for load_vlm_or_llm is missing a description for the use_fake_base parameter;
update the Args section to add a one-line explanation for use_fake_base (what it
toggles and its default behavior), e.g., indicate that use_fake_base controls
whether a FakeBaseModel is used when loading (default True) and how it interacts
with use_offline_training, so readers can understand its purpose alongside
model_name_or_path, use_offline_training, torch_dtype, device_map, and
trust_remote_code.
modelopt/torch/speculative/plugins/modeling_fakebase.py (1)

126-129: Consider using explicit exceptions instead of assert for shape validation.

Assertions can be disabled with -O (optimized mode). For production code that validates external checkpoint data, explicit ValueError or RuntimeError would be more robust.

♻️ Proposed refactor
-        assert lm_head_w.shape == (config.vocab_size, config.hidden_size)
-        assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size)
+        if lm_head_w.shape != (config.vocab_size, config.hidden_size):
+            raise ValueError(
+                f"lm_head shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+                f"got {lm_head_w.shape}"
+            )
+        if embed_tokens_w.shape != (config.vocab_size, config.hidden_size):
+            raise ValueError(
+                f"embed_tokens shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+                f"got {embed_tokens_w.shape}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 126 -
129, Replace the two assert checks with explicit runtime validation that raises
a clear exception (e.g., ValueError) when shapes mismatch: check that
lm_head_w.shape == (config.vocab_size, config.hidden_size) and
embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and if not raise
an error that includes the actual shapes and expected dimensions; keep the
subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and
self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid
checkpoint data fails loudly in production rather than being skipped when Python
assertions are disabled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 225-227: The code sets model_id to the hardcoded string
"kimi-k2.5" when model_source is provided, causing remote-model runs to share
the same eagle_output_dir; change the assignment for model_id to derive a unique
identifier from model_source (e.g., use the repository/name suffix:
model_source.split("/")[-1] or a sanitized version of that string) so model_id
is unique per model_source and output_subdir = eagle_output_dir /
f"eagle-{model_id}-offline" writes to distinct directories; update the model_id
assignment near the model_path/model_id definitions used in the test file
(model_path, model_id, eagle_output_dir).

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 126-129: Replace the two assert checks with explicit runtime
validation that raises a clear exception (e.g., ValueError) when shapes
mismatch: check that lm_head_w.shape == (config.vocab_size, config.hidden_size)
and embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and if not
raise an error that includes the actual shapes and expected dimensions; keep the
subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and
self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid
checkpoint data fails loudly in production rather than being skipped when Python
assertions are disabled.

In `@modelopt/torch/speculative/utils.py`:
- Around line 487-508: The docstring for load_vlm_or_llm is missing a
description for the use_fake_base parameter; update the Args section to add a
one-line explanation for use_fake_base (what it toggles and its default
behavior), e.g., indicate that use_fake_base controls whether a FakeBaseModel is
used when loading (default True) and how it interacts with use_offline_training,
so readers can understand its purpose alongside model_name_or_path,
use_offline_training, torch_dtype, device_map, and trust_remote_code.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d0e27af8-82ec-41a0-a296-d037db3dbd86

📥 Commits

Reviewing files that changed from the base of the PR and between 7b34de6 and 0df75e2.

📒 Files selected for processing (9)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • modelopt/torch/speculative/plugins/modeling_fakebase.py
  • modelopt/torch/speculative/plugins/transformers.py
  • modelopt/torch/speculative/utils.py
  • tests/examples/speculative_decoding/test_eagle.py

ChenhanYu (Collaborator) left a comment

This PR introduces FakeBaseModel for memory-efficient offline EAGLE training (loads only lm_head + embed_tokens), refactors load_vlm_or_llm_with_kwargs → load_vlm_or_llm with a cleaner API, consolidates model path constants, and adds several Kimi-K2.5 compatibility fixes.

Note: load_vlm_or_llm_with_kwargs → load_vlm_or_llm is a breaking API change — any downstream callers passing extra kwargs will break.

Test coverage is insufficient. The PR adds one integration test that exercises the happy path, but:

  • FakeBaseModel has no unit tests (weight loading, path auto-detection, error cases like missing safetensors index or unrecognized weight keys)
  • load_vlm_or_llm has no unit tests for its three distinct code paths (fake base, offline with num_hidden_layers=0, normal load)
  • The removed CompressedTensorsConfig ignore patch and attention mask repeat have no regression tests proving they're safe to remove
  • patched_decoder_layer_fwd (past_key_value/past_key_values fix) is untested
  • 3 of 4 test cases download from remote HF repos — needs @pytest.mark.slow or similar for CI reliability

Please add:

  1. Unit tests for FakeBaseModel (local path happy path, missing index, wrong keys)
  2. Unit tests for load_vlm_or_llm code paths
  3. @pytest.mark.slow or network markers for the remote-model tests

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 21-27: The module currently hard-imports optional packages
(transformers, huggingface_hub, safetensors) at top-level; change to guarded
imports following the codebase pattern: wrap imports in try/except and set
boolean flags (e.g., HAS_TRANSFORMERS, HAS_HF_HUB, HAS_SAFETENSORS) or move
those imports into the functions/methods that use them (e.g., where
hf_hub_download, EntryNotFoundError, safetensors_load_file, PretrainedConfig,
PreTrainedModel are referenced or where FakeBaseModel is constructed); ensure
any code that requires these libs checks the flags and raises a clear
ImportError if used without the [hf] extras.
- Line 115: The _load_weights() routine currently requires both 'lm_head.weight'
and 'embed_tokens.weight' even when config.tie_word_embeddings is True; update
_load_weights() to check self.config.tie_word_embeddings and, if True, call
_find_weight_key() once for the existing key (prefer 'embed_tokens.weight' but
accept either) and then assign that same tensor to both lm_head and embed_tokens
rather than requiring both keys; keep the existing behavior (separately finding
both keys) when tie_word_embeddings is False and continue to use
_find_weight_key() for key discovery and error reporting.
- Around line 109-115: The config construction and module creation currently
read getattr(base_cfg, "dtype", ...) and omit passing dtype to modules; change
to read getattr(base_cfg, "torch_dtype", None) (falling back to torch.bfloat16
only if None) when creating FakeBaseConfig and then ensure that created modules
(Embedding, Linear, and any parameter tensors) are instantiated with that dtype
so weights are allocated in the checkpoint dtype; update references around
FakeBaseConfig construction and places that instantiate nn.Embedding / nn.Linear
/ torch.nn.Parameter so they pass dtype=self.model.dtype (or the local
torch_dtype) to preserve fp16/bf16 and keep self.model.dtype accurate.

In `@modelopt/torch/speculative/utils.py`:
- Around line 514-516: The code currently reads attributes like
num_hidden_layers and layer_types directly from model_config (created via
transformers.AutoConfig.from_pretrained), but composite VLM configs nest these
under text_config/llm_config; apply the same unwrapping used in FakeBaseModel by
iterating _VLM_CONFIG_ATTRS (or using the same helper) to drill into the actual
LM config before checking hasattr(model_config, "layer_types") and before
assigning num_orig_hidden_layers, so that you first replace model_config with
the unwrapped sub-config (if present) and then proceed to read num_hidden_layers
and layer_types.
- Around line 450-452: In patched_decoder_layer_fwd, avoid clobbering an
existing legacy kwarg by only mapping the new name when present: if
"past_key_values" exists in kwargs, set kwargs["past_key_value"] =
kwargs.pop("past_key_values"); otherwise leave kwargs["past_key_value"]
untouched (do not set it to None). Update the logic in patched_decoder_layer_fwd
(which calls original_decoder_layer_forward) to perform this conditional
translation so callers that pass the old "past_key_value" continue to work.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d3fe6aa2-e68e-478f-9a0b-7ecfa5ff4b6c

📥 Commits

Reviewing files that changed from the base of the PR and between 0df75e2 and a023e6e.

📒 Files selected for processing (4)
  • examples/speculative_decoding/launch_train.sh
  • modelopt/torch/speculative/plugins/modeling_fakebase.py
  • modelopt/torch/speculative/utils.py
  • tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/speculative_decoding/launch_train.sh
  • tests/examples/speculative_decoding/test_eagle.py

@h-guo18 requested a review from ChenhanYu March 20, 2026 21:03
yeyu-nvidia (Contributor) left a comment

Review

Issues

1. Security: rglob("*.pt") widens attack surface (eagle_utils.py:157)

Switching from glob to rglob means recursively discovering .pt files from arbitrary subdirectories. Since .pt uses pickle, a malicious file in a nested dir could execute arbitrary code. Per SECURITY.md best practices, consider using weights_only=True in the corresponding torch.load call, or at least documenting the trust assumption.
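The discovery change plus the suggested mitigation can be sketched as follows (illustrative only; the real call site lives in eagle_utils.py, and the directory names here are made up):

```python
# rglob("*.pt") walks nested subdirectories, unlike glob("*.pt"),
# which only matches files directly under the data root.
import tempfile
from pathlib import Path

def find_pt_files(root: Path) -> list[Path]:
    return sorted(root.rglob("*.pt"))

with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    (root / "a.pt").touch()
    nested = root / "sub" / "deep"
    nested.mkdir(parents=True)
    (nested / "b.pt").touch()
    found = find_pt_files(root)
    print([p.name for p in found])  # → ['a.pt', 'b.pt']

# When loading, prefer restricting the unpickler:
#   data = torch.load(path, weights_only=True)
# so a malicious .pt in a nested dir cannot execute arbitrary code.
```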

2. Shared path constants have new entries that change matching behavior (modeling_fakebase.py / transformers.py)

The new shared lists add entries not present in main:

  • _BASE_MODEL_PATHS adds "language_model.model" as the first entry
  • _EMBED_TOKENS_PATHS adds "embed_tokens" and "language_model.model.embed_tokens" at positions 0 and 1

Since _find_base_model_parts matches the first hit, the new ordering could cause incorrect matches for existing models. For example, a model that has both a top-level embed_tokens (unrelated) and the real one at model.embed_tokens would now match the wrong one. The original ordering in main was carefully chosen — verify the new order doesn't break existing models (Llama, Gemma, etc.).

3. Missing num_orig_hidden_layers on FakeBaseModel path (utils.py:530-533)

When use_offline_training=True and use_fake_base=True, the function returns FakeBaseModel early without setting model.config.num_orig_hidden_layers. The non-fake offline path (line 562) does set this. The EAGLE training loop in main.py previously relied on this attribute. FakeBaseConfig does set num_hidden_layers from the original config, but num_orig_hidden_layers is never set. If downstream code reads num_orig_hidden_layers, this will AttributeError.

4. load_vlm_or_llm_with_kwargs removed without deprecation (utils.py)

The old function is deleted outright. If there are any external consumers (scripts, notebooks, other repos), this is a breaking change. A one-line deprecation wrapper forwarding to load_vlm_or_llm would be low-cost and safer.
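One possible shape for that shim; the stub loader below is hypothetical (the real load_vlm_or_llm signature is in modelopt/torch/speculative/utils.py) and only demonstrates the forwarding pattern.

```python
# Deprecated alias that forwards to the new API while warning callers.
import warnings

def load_vlm_or_llm(model_name_or_path, trust_remote_code=False, **_ignored):
    # Stand-in for the real loader; returns a marker for demonstration.
    return f"loaded:{model_name_or_path}"

def load_vlm_or_llm_with_kwargs(model_name_or_path, **kwargs):
    warnings.warn(
        "load_vlm_or_llm_with_kwargs is deprecated; use load_vlm_or_llm",
        DeprecationWarning,
        stacklevel=2,
    )
    return load_vlm_or_llm(model_name_or_path, **kwargs)

print(load_vlm_or_llm_with_kwargs("demo/model"))  # → loaded:demo/model
```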

5. Removed CompressedTensors quantization patch (transformers.py:577-580)

The CompressedTensorsConfig ignore patch for the EAGLE module was removed without explanation. On main, this prevents the drafter from getting quantized when loading Kimi-K2-Thinking with compressed tensors. If that model is still supported with compressed tensors checkpoints, removing this could cause the EAGLE drafter to be quantized unexpectedly.

6. Removed attention mask .repeat() for Kimi-K2 (transformers.py:730-731)

On main, tensor_mask.repeat(batch_size, 1, 1, 1) is present with the comment "repeat mask for kimi-k2 compatibility". The PR removes it. If Kimi-K2-Thinking still needs this, removing it will break that model's attention. The PR title says "Kimi-K2.5 fixes" — was this intentionally removed because K2.5 doesn't need it, or is K2-Thinking no longer supported?

7. FSDP default behavior change is silent (launch_train.sh:206)

On main, FSDP auto-enables for multi-GPU (TOTAL_GPU > 1). The PR changes this to require --fsdp=True explicitly. Existing users doing multi-GPU training without --fsdp will silently lose FSDP, potentially causing OOM or different training behavior. This should at minimum be documented in the PR description.

8. FakeBaseModel.__init__ signature is non-standard (modeling_fakebase.py:290)

The constructor takes source instead of the standard config parameter expected by PreTrainedModel. This means from_pretrained(), save_pretrained(), and other HF utilities won't work. Fine if intentional, but worth a docstring note that this class is construction-only and not compatible with the standard HF save/load API.

9. Tests download large remote models (test_eagle.py:238-241)

The parametrized test downloads configs/indices from moonshotai/Kimi-K2.5, moonshotai/Kimi-K2-Thinking, and MiniMaxAI/MiniMax-M2.5. These will be slow and require network access. They should be gated (e.g., @pytest.mark.slow or an env var check) to avoid breaking fast CI runs.

Minor/Style

  • modeling_fakebase.py: License header says "Copyright (c) 2024" — should be 2025 for new files.
  • test_eagle.py:641: Missing spaces after commas in test IDs: "kimi-k2.5","kimi-k2-thinking","minimax-m2.5".
  • FakeBaseConfig.dtype stores a torch.dtype object — this won't round-trip through HF's JSON config serialization cleanly.

What looks good

  • FakeBaseModel design is clean — loads only what's needed, auto-detects weight paths
  • load_vlm_or_llm with explicit parameters is much better API than **kwargs
  • The self._base_model_lm_head fix in offline mode is correct and important
  • Kimi decoder past_key_value/past_key_values patch is a clean compatibility fix
  • Single-node GPU detection via torch.cuda.device_count() respecting CUDA_VISIBLE_DEVICES is a good improvement
  • Removing hardcoded trust_remote_code=True in favor of a CLI flag is a security improvement
  • Unit tests are well-structured with proper monkeypatching

@h-guo18 (Contributor, Author) commented Mar 25, 2026


Comment 8 is addressed in the new commits. I don't think the other comments apply here — please let me know if you think they do. Thanks

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
modelopt/torch/speculative/utils.py (1)

450-454: ⚠️ Potential issue | 🟠 Major

Use conditional translation to avoid clobbering existing past_key_value kwarg.

The current implementation unconditionally sets past_key_value to None when past_key_values is absent. This breaks backward compatibility for callers that pass the legacy past_key_value parameter directly. Only translate when the newer key is present.

🐛 Proposed fix
     def patched_decoder_layer_fwd(self, *args, **kwargs):
-        kwargs["past_key_value"] = kwargs.pop("past_key_values", None)
+        if "past_key_values" in kwargs:
+            kwargs.setdefault("past_key_value", kwargs.pop("past_key_values"))
         return original_decoder_layer_forward(self, *args, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/utils.py` around lines 450 - 454,
patched_decoder_layer_fwd currently unconditionally sets
kwargs["past_key_value"] = kwargs.pop("past_key_values", None), which overwrites
callers that pass the legacy past_key_value; change it to only translate when
the new key exists and avoid clobbering an existing legacy value: inside
patched_decoder_layer_fwd check if "past_key_values" in kwargs, and if so, if
"past_key_value" not in kwargs then set kwargs["past_key_value"] =
kwargs.pop("past_key_values") else just kwargs.pop("past_key_values") to remove
the duplicate, then call original_decoder_layer_forward(self, *args, **kwargs);
keep assignment to kimi_k2_module.DeepseekV3DecoderLayer.forward as-is.
modelopt/torch/speculative/plugins/modeling_fakebase.py (3)

197-211: ⚠️ Potential issue | 🟠 Major

Handle tied-word-embedding checkpoints where only one weight key exists.

When config.tie_word_embeddings=True, HuggingFace safetensors serialization deduplicates tied weights, saving only one copy (typically embed_tokens.weight). The current implementation requires both keys, causing valid tied-embedding checkpoints to fail with RuntimeError. Check self.config.tie_word_embeddings and reuse the surviving weight for both modules.

🐛 Proposed fix
     def _load_weights(self, source: str):
         """Load lm_head and embed_tokens weights from a local directory or HuggingFace Hub repo."""
         weight_map = self._load_index(source)

-        lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head")
-        embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens")
-
-        lm_head_path, embed_tokens_path = self._resolve_shard_paths(
-            source, [weight_map[lm_head_key], weight_map[embed_tokens_key]]
-        )
-
-        lm_head_state = safetensors_load_file(lm_head_path, device="cpu")
-        embed_tokens_state = safetensors_load_file(embed_tokens_path, device="cpu")
-
-        return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key]
+        if self.config.tie_word_embeddings:
+            # Tied embeddings: only one weight is serialized, try embed_tokens first
+            try:
+                key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens")
+            except RuntimeError:
+                key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head")
+            shard_path = self._resolve_shard_paths(source, [weight_map[key]])[0]
+            state = safetensors_load_file(shard_path, device="cpu")
+            weight = state[key]
+            return weight, weight  # Same tensor for both
+        else:
+            lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head")
+            embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens")
+
+            lm_head_path, embed_tokens_path = self._resolve_shard_paths(
+                source, [weight_map[lm_head_key], weight_map[embed_tokens_key]]
+            )
+
+            lm_head_state = safetensors_load_file(lm_head_path, device="cpu")
+            embed_tokens_state = safetensors_load_file(embed_tokens_path, device="cpu")
+
+            return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 197 -
211, The _load_weights flow currently assumes both lm_head and embed_tokens keys
exist; when self.config.tie_word_embeddings is True safetensors may store only
one deduplicated key (typically embed_tokens.weight), causing a RuntimeError.
Update _load_weights to detect missing second key by inspecting weight_map and
self.config.tie_word_embeddings; if one key is absent, resolve shards for the
surviving key and reuse its tensor for both lm_head and embed_tokens outputs
(i.e., set lm_head_state_val = embed_tokens_state_val or vice versa) before
returning. Use the existing helpers (_find_weight_key, _resolve_shard_paths,
safetensors_load_file) and the variables lm_head_key, embed_tokens_key,
weight_map to locate and load the surviving weight and return the same tensor
for both modules when tied embeddings are enabled.

141-141: ⚠️ Potential issue | 🟠 Major

Use torch_dtype attribute instead of dtype to read checkpoint dtype.

HuggingFace PretrainedConfig exposes the checkpoint dtype via torch_dtype, not dtype. The current code reads dtype which doesn't exist, so it always defaults to torch.bfloat16 regardless of the actual checkpoint dtype. This undermines dtype consistency.

🐛 Proposed fix
         config = FakeBaseConfig(
             num_hidden_layers=getattr(base_cfg, "num_hidden_layers", None),
             hidden_size=getattr(base_cfg, "hidden_size", None),
             vocab_size=getattr(base_cfg, "vocab_size", None),
             max_position_embeddings=getattr(base_cfg, "max_position_embeddings", None),
-            dtype=getattr(base_cfg, "dtype", torch.bfloat16),
+            dtype=getattr(base_cfg, "torch_dtype", None) or torch.bfloat16,
             tie_word_embeddings=getattr(base_cfg, "tie_word_embeddings", False),
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` at line 141, The
code reads checkpoint dtype using getattr(base_cfg, "dtype", torch.bfloat16)
which is wrong; change it to read getattr(base_cfg, "torch_dtype",
torch.bfloat16) (optionally fallback to "dtype" if you want compatibility) so
the model uses the checkpoint's actual dtype; update the assignment where dtype
is set in modeling_fakebase.py (the line currently using getattr(base_cfg,
"dtype", ...)) to use "torch_dtype" (or a two-step check like
base_cfg.torch_dtype or base_cfg.dtype) to preserve dtype consistency.

21-33: 🛠️ Refactor suggestion | 🟠 Major

Gate optional dependency imports behind try/except or move to function scope.

Per coding guidelines, optional dependencies (transformers, huggingface_hub, safetensors) must be gated. These hard imports at module level will fail if the [hf] extra is not installed, even if FakeBaseModel is never used. The codebase pattern (seen in gradnas.py, etc.) uses try/except with flags or moves imports into functions.

♻️ Proposed fix - move imports to function/method scope
-import torch
-import torch.nn as nn
-import transformers
-from huggingface_hub import hf_hub_download
-from huggingface_hub.errors import EntryNotFoundError
-from safetensors.torch import load_file as safetensors_load_file
-from transformers import (
-    AutoConfig,
-    AutoModel,
-    AutoModelForCausalLM,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+import torch
+import torch.nn as nn
+
+try:
+    import transformers
+    from huggingface_hub import hf_hub_download
+    from huggingface_hub.errors import EntryNotFoundError
+    from safetensors.torch import load_file as safetensors_load_file
+    from transformers import (
+        AutoConfig,
+        AutoModel,
+        AutoModelForCausalLM,
+        PretrainedConfig,
+        PreTrainedModel,
+    )
+    _HAS_HF_DEPS = True
+except ImportError:
+    _HAS_HF_DEPS = False

Then check _HAS_HF_DEPS in from_source and raise a clear error if dependencies are missing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 21 -
33, Top-level imports for optional HF deps (transformers, huggingface_hub,
safetensors) must be gated; move those imports into the function/method scope or
wrap them in a try/except that sets a _HAS_HF_DEPS flag. Specifically, remove or
relocate the module-level imports in modeling_fakebase.py and either (a) import
AutoConfig/AutoModel/AutoModelForCausalLM/PretrainedConfig/PreTrainedModel,
hf_hub_download, EntryNotFoundError and safetensors_load_file inside
FakeBaseModel.from_source (or any helper it calls) or (b) wrap the top-level
imports in try/except to set _HAS_HF_DEPS=False and then at the start of
FakeBaseModel.from_source check _HAS_HF_DEPS and raise a clear error if missing;
reference FakeBaseModel and its from_source method when applying the change.
🧹 Nitpick comments (1)
tests/examples/speculative_decoding/test_eagle.py (1)

231-233: Consider using attribute-based VLM detection instead of hardcoded model name.

The current check only handles moonshotai/Kimi-K2.5 but other parametrized models (Kimi-K2-Thinking, MiniMax-M2.5) may also be VLMs requiring config unwrapping. This mirrors the pattern in FakeBaseModel.from_source which uses _VLM_CONFIG_ATTRS.

♻️ Proposed fix
     cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True)

-    if model_source=="moonshotai/Kimi-K2.5":
-        # vlm, get text config
-        cfg = cfg.text_config
+    # For VLMs, unwrap to the language model config
+    for attr in ["text_config", "llm_config"]:
+        if getattr(cfg, attr, None) is not None:
+            cfg = getattr(cfg, attr)
+            break
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/speculative_decoding/test_eagle.py` around lines 231 - 233,
The code currently checks model_source=="moonshotai/Kimi-K2.5" to decide to
unwrap a VLM config and set cfg = cfg.text_config; instead, detect VLMs by
attribute presence like FakeBaseModel.from_source does using _VLM_CONFIG_ATTRS:
check for any attribute in _VLM_CONFIG_ATTRS on the model source or cfg object
and if present set cfg = cfg.text_config (or the corresponding attr) so other
parametrized VLMs (e.g., Kimi-K2-Thinking, MiniMax-M2.5) are handled without
hardcoding model names; update the conditional around model_source and cfg to
use that attribute-based test and reuse the same attribute list used by
FakeBaseModel.from_source.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/utils.py`:
- Around line 514-538: The offline-training branch reads num_hidden_layers and
layer_types directly from model_config which fails for composite VLM configs;
before inspecting those attributes (in the block that sets extra and after
loading model_config) unwrap VLM configs the same way FakeBaseModel.from_source
does (use the same _VLM_CONFIG_ATTRS/text_config or llm_config lookup) so you
pull num_hidden_layers and layer_types from the nested text/llm config when
present; then set extra["num_hidden_layers"], extra["layer_types"] and later
model.config.num_orig_hidden_layers from the unwrapped values rather than the
top-level model_config fields.

---

Duplicate comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 197-211: The _load_weights flow currently assumes both lm_head and
embed_tokens keys exist; when self.config.tie_word_embeddings is True
safetensors may store only one deduplicated key (typically embed_tokens.weight),
causing a RuntimeError. Update _load_weights to detect missing second key by
inspecting weight_map and self.config.tie_word_embeddings; if one key is absent,
resolve shards for the surviving key and reuse its tensor for both lm_head and
embed_tokens outputs (i.e., set lm_head_state_val = embed_tokens_state_val or
vice versa) before returning. Use the existing helpers (_find_weight_key,
_resolve_shard_paths, safetensors_load_file) and the variables lm_head_key,
embed_tokens_key, weight_map to locate and load the surviving weight and return
the same tensor for both modules when tied embeddings are enabled.
- Line 141: The code reads checkpoint dtype using getattr(base_cfg, "dtype",
torch.bfloat16) which is wrong; change it to read getattr(base_cfg,
"torch_dtype", torch.bfloat16) (optionally fallback to "dtype" if you want
compatibility) so the model uses the checkpoint's actual dtype; update the
assignment where dtype is set in modeling_fakebase.py (the line currently using
getattr(base_cfg, "dtype", ...)) to use "torch_dtype" (or a two-step check like
base_cfg.torch_dtype or base_cfg.dtype) to preserve dtype consistency.
- Around line 21-33: Top-level imports for optional HF deps (transformers,
huggingface_hub, safetensors) must be gated; move those imports into the
function/method scope or wrap them in a try/except that sets a _HAS_HF_DEPS
flag. Specifically, remove or relocate the module-level imports in
modeling_fakebase.py and either (a) import
AutoConfig/AutoModel/AutoModelForCausalLM/PretrainedConfig/PreTrainedModel,
hf_hub_download, EntryNotFoundError and safetensors_load_file inside
FakeBaseModel.from_source (or any helper it calls) or (b) wrap the top-level
imports in try/except to set _HAS_HF_DEPS=False and then at the start of
FakeBaseModel.from_source check _HAS_HF_DEPS and raise a clear error if missing;
reference FakeBaseModel and its from_source method when applying the change.

In `@modelopt/torch/speculative/utils.py`:
- Around line 450-454: patched_decoder_layer_fwd currently unconditionally sets
kwargs["past_key_value"] = kwargs.pop("past_key_values", None), which overwrites
callers that pass the legacy past_key_value; change it to only translate when
the new key exists and avoid clobbering an existing legacy value: inside
patched_decoder_layer_fwd check if "past_key_values" in kwargs, and if so, if
"past_key_value" not in kwargs then set kwargs["past_key_value"] =
kwargs.pop("past_key_values") else just kwargs.pop("past_key_values") to remove
the duplicate, then call original_decoder_layer_forward(self, *args, **kwargs);
keep assignment to kimi_k2_module.DeepseekV3DecoderLayer.forward as-is.

---

Nitpick comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 231-233: The code currently checks
model_source=="moonshotai/Kimi-K2.5" to decide to unwrap a VLM config and set
cfg = cfg.text_config; instead, detect VLMs by attribute presence like
FakeBaseModel.from_source does using _VLM_CONFIG_ATTRS: check for any attribute
in _VLM_CONFIG_ATTRS on the model source or cfg object and if present set cfg =
cfg.text_config (or the corresponding attr) so other parametrized VLMs (e.g.,
Kimi-K2-Thinking, MiniMax-M2.5) are handled without hardcoding model names;
update the conditional around model_source and cfg to use that attribute-based
test and reuse the same attribute list used by FakeBaseModel.from_source.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d60a3a31-ec3e-4aea-8a42-f8fc76dd86af

📥 Commits

Reviewing files that changed from the base of the PR and between 99946d5 and c118d94.

📒 Files selected for processing (4)
  • modelopt/torch/speculative/plugins/modeling_fakebase.py
  • modelopt/torch/speculative/utils.py
  • tests/examples/speculative_decoding/test_eagle.py
  • tests/unit/torch/speculative/plugins/test_fakebase.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/torch/speculative/plugins/test_fakebase.py

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/fakebasemodel branch from c118d94 to cb6802c Compare March 26, 2026 04:00
@github-actions
Contributor

github-actions bot commented Mar 26, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1052/

Built to branch gh-pages at 2026-03-26 21:07 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/launch_train.sh (1)

220-226: ⚠️ Potential issue | 🟡 Minor

Document FSDP behavior change.

FSDP is now disabled by default and requires explicit --fsdp=True even for multi-GPU setups. Previously, FSDP was auto-enabled when TOTAL_GPU > 1. This is a breaking change for existing multi-GPU training workflows.

Consider documenting this change in the PR description or a changelog entry.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/launch_train.sh` around lines 220 - 226, Update
the PR description or project changelog to note that FSDP is now disabled by
default and must be set with --fsdp=True even on multi-GPU systems (previously
auto-enabled when TOTAL_GPU > 1); specifically mention the behavior around the
launch script variables TOTAL_GPU, FSDP, and FSDP_ARGS (including the use of
fsdp_config.json) and the potential breaking impact on multi-GPU workflows, and
optionally add a short warning in launch_train.sh when TOTAL_GPU>1 and FSDP is
not True so users see the change at runtime.
♻️ Duplicate comments (3)
modelopt/torch/speculative/utils.py (1)

450-452: ⚠️ Potential issue | 🟡 Minor

Preserve existing past_key_value kwarg.

The current implementation unconditionally sets past_key_value from past_key_values, potentially overwriting an existing past_key_value kwarg with None when past_key_values is absent. Only translate when past_key_values is actually present.

🔧 Proposed fix
     def patched_decoder_layer_fwd(self, *args, **kwargs):
-        kwargs["past_key_value"] = kwargs.pop("past_key_values", None)
+        if "past_key_values" in kwargs:
+            kwargs.setdefault("past_key_value", kwargs.pop("past_key_values"))
         return original_decoder_layer_forward(self, *args, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/utils.py` around lines 450 - 452,
patched_decoder_layer_fwd currently always sets kwargs["past_key_value"] =
kwargs.pop("past_key_values", None) which can overwrite an existing
past_key_value with None; change it to only transfer past_key_values when that
key exists (e.g., if "past_key_values" in kwargs: kwargs["past_key_value"] =
kwargs.pop("past_key_values")) before calling
original_decoder_layer_forward(self, *args, **kwargs); preserve any existing
past_key_value when past_key_values is absent.
modelopt/torch/speculative/plugins/modeling_fakebase.py (2)

141-141: ⚠️ Potential issue | 🟡 Minor

Use torch_dtype instead of dtype from config.

AutoConfig.from_pretrained() exposes the checkpoint dtype via torch_dtype, not dtype. This will always fall back to torch.bfloat16 regardless of the actual checkpoint dtype.

-            dtype=getattr(base_cfg, "dtype", torch.bfloat16),
+            dtype=getattr(base_cfg, "torch_dtype", None) or torch.bfloat16,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` at line 141, The
config uses getattr(base_cfg, "dtype", torch.bfloat16) which is incorrect
because AutoConfig exposes the checkpoint dtype as torch_dtype; update the call
in modeling_fakebase.py (where base_cfg is used to set dtype) to use
getattr(base_cfg, "torch_dtype", torch.bfloat16) so the actual checkpoint dtype
is respected while keeping torch.bfloat16 as the fallback.

197-211: ⚠️ Potential issue | 🟠 Major

Handle tied-word-embedding checkpoints.

When config.tie_word_embeddings=True, HuggingFace safetensors typically saves only one copy of the tied weights (usually embed_tokens.weight). The current _load_weights() unconditionally requires both keys, which will cause RuntimeError for valid tied-embedding checkpoints.

Check self.config.tie_word_embeddings and reuse the surviving weight key for both tensors when True.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 197 -
211, The _load_weights method currently assumes both lm_head and embed_tokens
keys exist; update it to handle tied embeddings by checking
self.config.tie_word_embeddings and, if True, detect when one key is missing and
reuse the present key for both tensors (e.g., if only embed_tokens_key exists,
set lm_head_key = embed_tokens_key, or vice versa) before calling
_resolve_shard_paths, safetensors_load_file, and returning the pair; adjust
logic around _find_weight_key, lm_head_key, embed_tokens_key,
_resolve_shard_paths, and the final return so the same loaded tensor is returned
for both lm_head and embed_tokens when tie_word_embeddings is True.
🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/modeling_fakebase.py (1)

1-1: Update copyright year to 2025.

The copyright year should be 2025 for new files added in 2025/2026.

-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` at line 1, Update
the SPDX header year from 2024 to 2025 at the top of the file: change the
copyright notice in the SPDX comment (the first line in modeling_fakebase.py) so
it reads 2025 instead of 2024.
tests/examples/speculative_decoding/test_eagle.py (2)

216-216: Minor: Add consistent spacing in test IDs.

-    ids=["tinyllama", "kimi-k2.5","kimi-k2-thinking","minimax-m2.5"],
+    ids=["tinyllama", "kimi-k2.5", "kimi-k2-thinking", "minimax-m2.5"],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/speculative_decoding/test_eagle.py` at line 216, In the
tests/examples/speculative_decoding/test_eagle.py file update the ids list
passed in the test (the ids=["tinyllama",
"kimi-k2.5","kimi-k2-thinking","minimax-m2.5"] entry) to use consistent spacing
after commas so each item is separated by ", " (e.g. "tinyllama", "kimi-k2.5",
"kimi-k2-thinking", "minimax-m2.5") for consistent style.

274-308: Consider using pytest.mark.dependency or fixtures for test ordering.

test_offline_resume_training_kimi depends on test_offline_eagle3_training["kimi-k2.5"] having run first to create the checkpoint at eagle-Kimi-K2.5-offline. While pytest typically runs tests in file order, this implicit dependency is fragile.

Consider:

  1. Using pytest-depends with @pytest.mark.depends(on=...), or
  2. Converting the initial training into a session-scoped fixture that the resume test depends on.
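A minimal sketch of option 2; the fixture and helper names below are illustrative, not the repository's actual helpers:

```python
import pytest


@pytest.fixture(scope="session")
def eagle_kimi_checkpoint(tmp_path_factory):
    """Run the initial offline training once per session and hand out its dir."""
    ckpt_dir = tmp_path_factory.mktemp("eagle-Kimi-K2.5-offline")
    # run_offline_eagle3_training(output_dir=ckpt_dir)  # hypothetical helper
    return ckpt_dir


def test_offline_resume_training_kimi(eagle_kimi_checkpoint):
    # pytest guarantees the checkpoint fixture ran before this test.
    ...
```

With the fixture injected, the resume test no longer relies on file-order execution to find the checkpoint.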
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/speculative_decoding/test_eagle.py` around lines 274 - 308,
The resume test test_offline_resume_training_kimi has an implicit dependency on
test_offline_eagle3_training["kimi-k2.5"] creating the checkpoint; make this
explicit by either (A) adding a dependency marker like
`@pytest.mark.depends(on="test_offline_eagle3_training[kimi-k2.5]")` above
test_offline_resume_training_kimi, or (B) refactoring the initial training into
a session-scoped fixture (e.g., eagle_kimi_checkpoint_fixture) that returns the
checkpoint_dir and then inject that fixture into
test_offline_resume_training_kimi (replace direct use of
eagle_output_dir/eagle-Kimi-K2.5-offline), ensuring the resume test uses the
fixture so pytest guarantees order and availability of the checkpoint.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@examples/speculative_decoding/launch_train.sh`:
- Around line 220-226: Update the PR description or project changelog to note
that FSDP is now disabled by default and must be set with --fsdp=True even on
multi-GPU systems (previously auto-enabled when TOTAL_GPU > 1); specifically
mention the behavior around the launch script variables TOTAL_GPU, FSDP, and
FSDP_ARGS (including the use of fsdp_config.json) and the potential breaking
impact on multi-GPU workflows, and optionally add a short warning in
launch_train.sh when TOTAL_GPU>1 and FSDP is not True so users see the change at
runtime.

---

Duplicate comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Line 141: The config uses getattr(base_cfg, "dtype", torch.bfloat16) which is
incorrect because AutoConfig exposes the checkpoint dtype as torch_dtype; update
the call in modeling_fakebase.py (where base_cfg is used to set dtype) to use
getattr(base_cfg, "torch_dtype", torch.bfloat16) so the actual checkpoint dtype
is respected while keeping torch.bfloat16 as the fallback.
- Around line 197-211: The _load_weights method currently assumes both lm_head
and embed_tokens keys exist; update it to handle tied embeddings by checking
self.config.tie_word_embeddings and, if True, detect when one key is missing and
reuse the present key for both tensors (e.g., if only embed_tokens_key exists,
set lm_head_key = embed_tokens_key, or vice versa) before calling
_resolve_shard_paths, safetensors_load_file, and returning the pair; adjust
logic around _find_weight_key, lm_head_key, embed_tokens_key,
_resolve_shard_paths, and the final return so the same loaded tensor is returned
for both lm_head and embed_tokens when tie_word_embeddings is True.

In `@modelopt/torch/speculative/utils.py`:
- Around line 450-452: patched_decoder_layer_fwd currently always sets
kwargs["past_key_value"] = kwargs.pop("past_key_values", None) which can
overwrite an existing past_key_value with None; change it to only transfer
past_key_values when that key exists (e.g., if "past_key_values" in kwargs:
kwargs["past_key_value"] = kwargs.pop("past_key_values")) before calling
original_decoder_layer_forward(self, *args, **kwargs); preserve any existing
past_key_value when past_key_values is absent.

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Line 1: Update the SPDX header year from 2024 to 2025 at the top of the file:
change the copyright notice in the SPDX comment (the first line in
modeling_fakebase.py) so it reads 2025 instead of 2024.

In `@tests/examples/speculative_decoding/test_eagle.py`:
- Line 216: In the tests/examples/speculative_decoding/test_eagle.py file update
the ids list passed in the test (the ids=["tinyllama",
"kimi-k2.5","kimi-k2-thinking","minimax-m2.5"] entry) to use consistent spacing
after commas so each item is separated by ", " (e.g. "tinyllama", "kimi-k2.5",
"kimi-k2-thinking", "minimax-m2.5") for consistent style.
- Around line 274-308: The resume test test_offline_resume_training_kimi has an
implicit dependency on test_offline_eagle3_training["kimi-k2.5"] creating the
checkpoint; make this explicit by either (A) adding a dependency marker like
`@pytest.mark.depends(on="test_offline_eagle3_training[kimi-k2.5]")` above
test_offline_resume_training_kimi, or (B) refactoring the initial training into
a session-scoped fixture (e.g., eagle_kimi_checkpoint_fixture) that returns the
checkpoint_dir and then inject that fixture into
test_offline_resume_training_kimi (replace direct use of
eagle_output_dir/eagle-Kimi-K2.5-offline), ensuring the resume test uses the
fixture so pytest guarantees order and availability of the checkpoint.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8d810a1a-2446-4f38-ba0c-8b9fa350b134

📥 Commits

Reviewing files that changed from the base of the PR and between c118d94 and cb6802c.

📒 Files selected for processing (10)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • modelopt/torch/speculative/plugins/modeling_fakebase.py
  • modelopt/torch/speculative/plugins/transformers.py
  • modelopt/torch/speculative/utils.py
  • tests/examples/speculative_decoding/test_eagle.py
  • tests/unit/torch/speculative/plugins/test_fakebase.py
✅ Files skipped from review due to trivial changes (1)
  • tests/unit/torch/speculative/plugins/test_fakebase.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • examples/speculative_decoding/main.py

h-guo18 added 5 commits March 26, 2026 05:27
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>