diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 8c96a19a7..1bc7df981 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -70,7 +70,15 @@ def __len__(self): return len(self.dumped_files) def __getitem__(self, i) -> dict[str, torch.Tensor]: - offline_data = torch.load(self.dumped_files[i]) + try: + offline_data = torch.load(self.dumped_files[i]) + except Exception as e: + print( + f"[ERROR] Failed to load file at index={i}, " + f"path='{self.dumped_files[i]}', error={e}. " + "Falling back to the previous sample (wrapping at index 0)." + ) + return self.__getitem__((i - 1) % len(self)) labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID) labels[..., :-1] = offline_data["input_ids"][..., 1:] @@ -154,7 +162,7 @@ make_eagle_supervised_data_module( assert not data_args.vlm_processor, "Offline data is not supported for VLM." offline_data_path = Path(data_args.offline_data_path) - dumped_files = [str(p) for p in offline_data_path.glob("*.pt")] + dumped_files = [str(p) for p in offline_data_path.rglob("*.pt")] if not dumped_files: raise ValueError(f"No .pt files found in {data_args.offline_data_path}") diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 0ffe17486..c15b97bda 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -122,6 +122,18 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DISABLE_TORCH_COMPILE="${1#*=}" ;; + --use_fake_base_for_offline*) + if [[ "$1" != *=* ]]; then shift; fi + USE_FAKE_BASE_FOR_OFFLINE="${1#*=}" + ;; + --trust_remote_code*) + if [[ "$1" != *=* ]]; then shift; fi + TRUST_REMOTE_CODE="${1#*=}" + ;; + --fsdp*) + if [[ "$1" != *=* ]]; then shift; fi + FSDP="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -134,9 +146,16 @@ set -x SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
NUM_NODES=${NUM_NODES:-1} -GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} -TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) -echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +if [[ "$NUM_NODES" != 1 ]]; then + #Multi Node Training + GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} + TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) + echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +else + #Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES + TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (Single Node Training)" +fi # Calculate save_steps DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) @@ -151,6 +170,7 @@ SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} LR=${LR:-"1e-4"} TRAIN_BS=${TRAIN_BS:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} +DATA=${DATA:-""} OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} DISABLE_TQDM=${DISABLE_TQDM:-False} VLM_PROCESSOR=${VLM_PROCESSOR:-} @@ -165,6 +185,9 @@ MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} +USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} +TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} +FSDP=${FSDP:-"False"} if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then @@ -182,10 +205,10 @@ if [[ "$OFFLINE_DATA_PATH" != "" ]]; then echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory." 
exit 1 else - OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1" + DATA_ARGS="--offline-data-path $OFFLINE_DATA_PATH"; AR_VALIDATE_STEPS=-1 fi else - OFFLINE_TRAINING_ARGS="" + DATA_ARGS="--data_path $DATA" fi @@ -195,7 +218,7 @@ else VLM_ARGS="" fi -if [[ "$TOTAL_GPU" -gt 1 ]]; then +if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then #Use FSDP2 when multi GPU available FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" else @@ -245,15 +268,16 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --lr_scheduler_type linear \ --logging_steps $LOG_STEPS \ --tf32 True \ - --data_path $DATA \ + $DATA_ARGS \ --disable_tqdm $DISABLE_TQDM \ --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ --mix_hidden_states $MIX_HIDDEN_STATES \ --disable_torch_compile $DISABLE_TORCH_COMPILE \ + --use_fake_base_for_offline $USE_FAKE_BASE_FOR_OFFLINE \ + --trust_remote_code $TRUST_REMOTE_CODE \ $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ - $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS \ $FSDP_ARGS \ --cp_size $CP_SIZE \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 0db3867cc..3369d399c 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -47,10 +47,7 @@ import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.utils import ( - load_vlm_or_llm_with_kwargs, - patch_transformers5_params_loading, -) +from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 torch.manual_seed(0) @@ -60,11 +57,18 @@ @dataclass class ModelArguments: model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + use_fake_base_for_offline: bool = field( + default=False, metadata={"help": "Whether to use fake base for offline training."} + ) + 
trust_remote_code: bool = field( + default=False, metadata={"help": "Whether to trust remote code."} + ) @dataclass class DataArguments: data_path: str = field( + default=None, metadata={"help": "Path to the training data."}, ) eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."}) @@ -153,6 +157,8 @@ def train(): model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) + if not data_args.data_path and not data_args.offline_data_path: + raise ValueError("Either data_path or offline_data_path must be provided.") if training_args.cp_size > 1 or training_args.dp_shard_size > 1: training_args.parallelism_config = ParallelismConfig( cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size @@ -178,29 +184,27 @@ def train(): if checkpoint: with patch_transformers5_params_loading(): - _, model = load_vlm_or_llm_with_kwargs( - checkpoint, torch_dtype="auto", trust_remote_code=True + model = load_vlm_or_llm( + checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code ) - tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) + tokenizer = transformers.AutoTokenizer.from_pretrained( + checkpoint, trust_remote_code=model_args.trust_remote_code + ) else: # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). 
- offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} - model_config, model = load_vlm_or_llm_with_kwargs( + model = load_vlm_or_llm( model_args.model_name_or_path, + use_fake_base=model_args.use_fake_base_for_offline, + use_offline_training=use_offline_training, torch_dtype="auto", device_map="cpu", - trust_remote_code=True, - **offline_kwargs, + trust_remote_code=model_args.trust_remote_code, ) - if use_offline_training: - # When doing offline training, we need to set num_hidden_layers - # since we override it when loading the model for space savings - model.config.num_orig_hidden_layers = model_config.num_hidden_layers tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, - trust_remote_code=True, + trust_remote_code=model_args.trust_remote_code, ) if training_args.mode == "medusa": config = { diff --git a/examples/speculative_decoding/guides/CR2_eagle_config.json b/examples/speculative_decoding/recipes/CR2_eagle_config.json similarity index 100% rename from examples/speculative_decoding/guides/CR2_eagle_config.json rename to examples/speculative_decoding/recipes/CR2_eagle_config.json diff --git a/examples/speculative_decoding/guides/nemotron_mapping.bin b/examples/speculative_decoding/recipes/nemotron_mapping.bin similarity index 100% rename from examples/speculative_decoding/guides/nemotron_mapping.bin rename to examples/speculative_decoding/recipes/nemotron_mapping.bin diff --git a/examples/speculative_decoding/guides/train_eagle_head_cosmos_reason2.ipynb b/examples/speculative_decoding/recipes/train_eagle_head_cosmos_reason2.ipynb similarity index 100% rename from examples/speculative_decoding/guides/train_eagle_head_cosmos_reason2.ipynb rename to examples/speculative_decoding/recipes/train_eagle_head_cosmos_reason2.ipynb diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index 
d5c37a895..d1bf31a1a 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -22,7 +22,7 @@ import modelopt.torch.opt as mto from modelopt.torch.speculative.plugins.transformers import HFARValidation -from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs +from modelopt.torch.speculative.utils import load_vlm_or_llm mto.enable_huggingface_checkpointing() @@ -72,7 +72,7 @@ def main(): accelerator = Accelerator() # Load model and tokenizer - _, model = load_vlm_or_llm_with_kwargs(args.model_path, device_map="auto") + model = load_vlm_or_llm(args.model_path, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(args.model_path) model.eval() model = accelerator.prepare(model) diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index 23a7560f7..925f4b73d 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -21,7 +21,7 @@ import modelopt.torch.opt as mto from modelopt.torch.export import export_speculative_decoding -from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs +from modelopt.torch.speculative.utils import load_vlm_or_llm def parse_args(): @@ -38,7 +38,7 @@ def parse_args(): mto.enable_huggingface_checkpointing() args = parse_args() -_, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto") +model = load_vlm_or_llm(args.model_path, torch_dtype="auto") model.eval() with torch.inference_mode(): export_speculative_decoding( diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py new file mode 100644 index 000000000..4ed06ed64 --- /dev/null +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lightweight fake base model for offline speculative decoding training.""" + +import json +import os + +import torch +import torch.nn as nn +import transformers +from huggingface_hub import hf_hub_download +from huggingface_hub.errors import EntryNotFoundError +from safetensors.torch import load_file as safetensors_load_file +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + PretrainedConfig, + PreTrainedModel, +) + +# Candidate module paths searched in order — shared with HFEagleModel._find_base_model_parts +_EMBED_TOKENS_PATHS = [ + "embed_tokens", + "language_model.model.embed_tokens", + "model.embed_tokens", + "backbone.embeddings", + "language_model.backbone.embeddings", + "model.language_model.embed_tokens", +] +_LM_HEAD_PATHS = ["lm_head", "language_model.lm_head"] +_BASE_MODEL_PATHS = [ + "language_model.model", + "model.language_model", + "model", + "backbone", + "language_model.backbone", +] +_VLM_CONFIG_ATTRS = ["text_config", "llm_config"] +_SAFETENSORS_INDEX_FILENAME = "model.safetensors.index.json" + + +class FakeBaseConfig(PretrainedConfig): + """Minimal config for FakeBaseModel that supports offline speculative decoding training.""" + + model_type = "fake_base_model" + + def __init__( + self, + num_hidden_layers=None, + hidden_size=None, + vocab_size=None, + max_position_embeddings=None, + 
dtype=torch.bfloat16, + tie_word_embeddings=False, + **kwargs, + ): + """Initialize FakeBaseConfig with minimal model configuration parameters.""" + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + if isinstance(dtype, str): + dtype = getattr(torch, dtype) + self.dtype = dtype + + +class FakeBaseModel(PreTrainedModel): + """Minimal base model for offline speculative decoding. + + Contains only ``lm_head``, ``embed_tokens``, and the minimal config needed by the EAGLE + training loop. The full model weights are never loaded, keeping memory usage low. + + Weights are loaded from a local HuggingFace checkpoint directory. Weight key names and + VLM config nesting are auto-detected from the shared path constants. + """ + + config_class = FakeBaseConfig + + def __init__(self, config: FakeBaseConfig, **kwargs): + """Initialize FakeBaseModel structure from a FakeBaseConfig. + + To construct a FakeBaseModel from an original HuggingFace checkpoint (e.g. a Llama + repo), use the :meth:`from_source` classmethod instead. + """ + super().__init__(config, **kwargs) + # Initialize dummy module and attributes for compatibility with HFEagleModel + self.model = nn.Module() + self.model.layers = nn.ModuleList() + self.model.dtype = config.dtype + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=config.dtype) + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=False, dtype=config.dtype + ) + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_source(cls, source: str, trust_remote_code: bool = False) -> "FakeBaseModel": + """Load lm_head and embed_tokens from a local directory or HuggingFace Hub repo. + + Args: + source: Path to a local HuggingFace checkpoint directory, or a HuggingFace Hub + repo ID (e.g. 
``"meta-llama/Llama-3.1-8B"``). The source type is detected + automatically: if ``source`` is an existing local directory it is treated as a + local checkpoint; otherwise it is treated as a Hub repo ID and the required + files are downloaded via ``huggingface_hub``. + """ + orig_config = transformers.AutoConfig.from_pretrained( + source, trust_remote_code=trust_remote_code + ) + # For vlms, detect language model config based on _VLM_CONFIG_ATTRS + base_cfg = next( + ( + getattr(orig_config, attr) + for attr in _VLM_CONFIG_ATTRS + if getattr(orig_config, attr, None) is not None + ), + orig_config, + ) + # Extract necessary info for spec training from base config + config = FakeBaseConfig( + num_hidden_layers=getattr(base_cfg, "num_hidden_layers", None), + hidden_size=getattr(base_cfg, "hidden_size", None), + vocab_size=getattr(base_cfg, "vocab_size", None), + max_position_embeddings=getattr(base_cfg, "max_position_embeddings", None), + dtype=getattr(base_cfg, "dtype", torch.bfloat16), + tie_word_embeddings=getattr(base_cfg, "tie_word_embeddings", False), + ) + model = cls(config) + # Load lm_head and embed_tokens only from checkpoint + lm_head_w, embed_tokens_w = model._load_weights(source) + assert lm_head_w.shape == (config.vocab_size, config.hidden_size) + assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size) + model.lm_head.weight.data.copy_(lm_head_w) + model.embed_tokens.weight.data.copy_(embed_tokens_w) + return model + + @staticmethod + def _find_weight_key(weight_map: dict, paths: list[str], label: str) -> str: + """Return the first ``path + '.weight'`` found in ``weight_map``.""" + for path in paths: + key = path + ".weight" + if key in weight_map: + return key + tried = [p + ".weight" for p in paths] + raise RuntimeError(f"Cannot find {label} in checkpoint; tried: {tried}") + + @staticmethod + def _load_index(source: str) -> dict: + """Load weight_map from model.safetensors.index.json (local directory or Hub repo).""" + if 
os.path.isdir(source): + index_path = os.path.join(source, _SAFETENSORS_INDEX_FILENAME) + if not os.path.isfile(index_path): + raise FileNotFoundError( + f"No {_SAFETENSORS_INDEX_FILENAME} found in {source!r}. " + "FakeBaseModel only supports safetensors checkpoints. " + "Checkpoints using pytorch_model.bin or single-file formats are not supported." + ) + else: + try: + index_path = hf_hub_download(repo_id=source, filename=_SAFETENSORS_INDEX_FILENAME) + except EntryNotFoundError: + raise ValueError( + f"Repository {source!r} does not contain {_SAFETENSORS_INDEX_FILENAME}. " + "FakeBaseModel only supports safetensors checkpoints. " + "Checkpoints using pytorch_model.bin or single-file formats are not supported." + ) from None + with open(index_path) as f: + return json.load(f).get("weight_map", {}) + + @staticmethod + def _resolve_shard_paths(source: str, shard_filenames: list[str]) -> list[str]: + """Return local filesystem paths for each shard filename. + + For a local directory the paths are joined directly; for a HuggingFace Hub repo ID the + shards are downloaded via ``hf_hub_download`` (cached on subsequent calls). 
+ """ + if os.path.isdir(source): + return [os.path.join(source, name) for name in shard_filenames] + return [hf_hub_download(repo_id=source, filename=name) for name in shard_filenames] + + def _load_weights(self, source: str): + """Load lm_head and embed_tokens weights from a local directory or HuggingFace Hub repo.""" + weight_map = self._load_index(source) + + lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") + embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens") + + lm_head_path, embed_tokens_path = self._resolve_shard_paths( + source, [weight_map[lm_head_key], weight_map[embed_tokens_key]] + ) + + lm_head_state = safetensors_load_file(lm_head_path, device="cpu") + embed_tokens_state = safetensors_load_file(embed_tokens_path, device="cpu") + + return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key] + + def forward(self, *args, **kwargs): + """Not implemented: FakeBaseModel omits full model weights and cannot run inference.""" + raise NotImplementedError("FakeBaseModel forward is not implemented.") + + +# Register so that AutoConfig / AutoModel / AutoModelForCausalLM can resolve "fake_base_model". 
+AutoConfig.register("fake_base_model", FakeBaseConfig) +AutoModel.register(FakeBaseConfig, FakeBaseModel) +AutoModelForCausalLM.register(FakeBaseConfig, FakeBaseModel) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 25946f2c1..8561a390f 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -46,7 +46,6 @@ ) from transformers.trainer_pt_utils import LabelSmoother from transformers.utils import ModelOutput -from transformers.utils.quantization_config import CompressedTensorsConfig from ...export.plugins.hf_spec_export import ( EagleExporter, @@ -66,6 +65,7 @@ get_ttt_msk_func, temporary_set_config_value, ) +from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS __all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"] @@ -475,19 +475,9 @@ def get_exporter(self) -> SpeculativeDecodingExporter: def _find_base_model_parts(self): """Find model parts from different models and set base_{part}_path attributes.""" base_model_parts_mapping = { - "base_model_path": [ - "model.language_model", - "model", - "backbone", - "language_model.backbone", - ], - "base_model_embeddings_path": [ - "model.embed_tokens", - "backbone.embeddings", - "language_model.backbone.embeddings", - "model.language_model.embed_tokens", - ], - "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + "base_model_path": _BASE_MODEL_PATHS, + "base_model_embeddings_path": _EMBED_TOKENS_PATHS, + "base_model_lm_head_path": _LM_HEAD_PATHS, } for name, paths in base_model_parts_mapping.items(): @@ -588,11 +578,6 @@ def modify( if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" - # Patch for Kimi-K2-Thinking, avoid quantizing drafter - quant_config = getattr(self.config, "quantization_config", None) - if isinstance(quant_config, CompressedTensorsConfig): - 
quant_config.ignore.append("re:.*eagle_module.*") - # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state @@ -602,7 +587,7 @@ def modify( # Freeze all parameters if self.eagle_freeze_base_model: - for name, param in self.named_parameters(): + for _, param in self.named_parameters(): param.requires_grad = False self.eagle_module = EagleModule( @@ -785,8 +770,6 @@ def _compute_ttt_attention_mask( tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device ).masked_fill(~tensor_mask, dtypemin) - # Note: (hg) repeat mask for kimi-k2 compatibility - tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask def _base_model_forward( @@ -912,7 +895,7 @@ def forward( assert "base_model_outputs" in kwargs base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) if base_outputs.logits is None: - base_outputs.logits = self.lm_head(base_outputs.out_hiddens) + base_outputs.logits = self._base_model_lm_head(base_outputs.out_hiddens) past_key_values = None else: with self._nvtx_range("base_model_forward"): diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 72c5b5dbc..9e167c8dc 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -443,6 +443,16 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): kimi_k2_module.DeepseekV3Attention._init_rope = lambda self: None kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init + # Kimi implementation is based on older transformers which use "past_key_value" argument + # We patch it to "past_key_values" for compatibility + original_decoder_layer_forward = kimi_k2_module.DeepseekV3DecoderLayer.forward + + def patched_decoder_layer_fwd(self, *args, **kwargs): + kwargs["past_key_value"] = kwargs.pop("past_key_values", None) + return original_decoder_layer_forward(self, *args, **kwargs) + + kimi_k2_module.DeepseekV3DecoderLayer.forward = 
patched_decoder_layer_fwd + return getattr(kimi_k2_module, "DeepseekV3DecoderLayer") @@ -474,21 +484,60 @@ def enable_cp_ttt_patch(): modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False -def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs): - """Load a VLM or LLM with kwargs. Returns the model and model config.""" +def load_vlm_or_llm( + model_name_or_path: str, + use_fake_base: bool = False, + use_offline_training: bool = False, + torch_dtype: str | torch.dtype | None = None, + device_map: str | None = None, + trust_remote_code: bool = False, +): + """Load a VLM or LLM. Returns the model. + + When ``use_offline_training=True``, returns a + :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseModel` containing only + ``lm_head`` and ``embed_tokens``, auto-detecting weight paths from the checkpoint. + Otherwise, falls back to loading with ``num_hidden_layers=0`` for memory efficiency. + + Args: + model_name_or_path: Local path or HuggingFace repo ID of the model. + use_offline_training: Whether to load a memory-efficient model for offline training. + torch_dtype: dtype to use when loading the model. + device_map: Device map passed to ``from_pretrained``. + trust_remote_code: Whether to trust remote code. 
+ """ + if use_offline_training and use_fake_base: + from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel + + return FakeBaseModel.from_source(model_name_or_path, trust_remote_code=trust_remote_code) + model_config = transformers.AutoConfig.from_pretrained( - model_name_or_path, trust_remote_code=True + model_name_or_path, trust_remote_code=trust_remote_code ) if "vl" in model_config.model_type.lower(): model_cls = transformers.AutoModelForVision2Seq else: model_cls = transformers.AutoModelForCausalLM - if kwargs.get("num_hidden_layers") == 0: + extra = {} + if use_offline_training: + extra["num_hidden_layers"] = 0 if hasattr(model_config, "layer_types"): - kwargs["layer_types"] = [] + extra["layer_types"] = [] + + model = model_cls.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, + device_map=device_map, + **extra, + ) + + if use_offline_training: + # Preserve the original layer count since we loaded with num_hidden_layers=0 + model.config.num_orig_hidden_layers = model_config.num_hidden_layers - return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs) + return model @contextlib.contextmanager @@ -504,10 +553,12 @@ def patch_transformers5_params_loading(): """ # Skip patching for non-applicable transformers version if importlib.util.find_spec("transformers.core_model_loading") is None: + yield return from transformers import core_model_loading if not hasattr(core_model_loading, "set_param_for_module"): + yield return orig_set_param_for_module = core_model_loading.set_param_for_module diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index a3542fa25..271241bcb 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -15,6 +15,7 @@ import json import os +from pathlib import Path import pytest import safetensors.torch @@ -25,6 +26,44 @@ 
from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER +def generate_offline_pt_data( + output_dir, + num_files: int = 8, + seq_len: int = 128, + hidden_size: int = 512, + vocab_size: int = 32000, + num_aux_layers: int = 2, +) -> Path: + """Generate fake offline training .pt files for EAGLE3 offline training tests. + + Each file contains the keys expected by OfflineSupervisedDataset: + - input_ids: LongTensor of shape (seq_len,) + - hidden_states: FloatTensor of shape (seq_len, hidden_size) + - aux_hidden_states: FloatTensor of shape (seq_len, hidden_size*num_aux_layers) + + Args: + output_dir: Directory to write .pt files into. + num_files: Number of .pt files to generate. + seq_len: Sequence length. Defaults to 128. + hidden_size: Hidden size matching the base model. Defaults to 512 (tiny_llama). + vocab_size: Vocabulary size matching the base model. Defaults to 32000 (tiny_llama). + num_aux_layers: Number of auxiliary layers. Defaults to 2. + Returns: + Path to the output directory. 
+ """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + torch.manual_seed(42) + for i in range(num_files): + sample = { + "input_ids": torch.randint(0, vocab_size, (seq_len,)), + "hidden_states": torch.randn(seq_len, hidden_size), + "aux_hidden_states": torch.randn(seq_len, hidden_size * num_aux_layers), + } + torch.save(sample, output_dir / f"sample_{i:04d}.pt") + return output_dir + + @pytest.fixture(scope="module") def eagle_output_dir(tmp_path_factory): """Eagle output directory shared in this module.""" @@ -164,3 +203,106 @@ def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir): ], "speculative_decoding", ) + + +@pytest.mark.parametrize( + ("model_source", "use_fake_base"), + [ + (None, False), # tiny_llama (from fixture), no FakeBase + ("moonshotai/Kimi-K2.5", True), # remote HF repo, FakeBaseModel + ("moonshotai/Kimi-K2-Thinking", True), # remote HF repo, no FakeBaseModel + ("MiniMaxAI/MiniMax-M2.5", True), + ], + ids=["tinyllama", "kimi-k2.5","kimi-k2-thinking","minimax-m2.5"], +) +def test_offline_eagle3_training( + tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, + model_source, use_fake_base, +): + """Test Eagle3 training with pre-computed hidden states (offline mode / FakeBaseModel).""" + import transformers + + model_path = tiny_llama_path if model_source is None else model_source + model_id = "tinyllama" if model_source is None else model_source.split("/")[-1] + output_subdir = eagle_output_dir / f"eagle-{model_id}-offline" + + cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + if model_source=="moonshotai/Kimi-K2.5": + #vlm, get text config + cfg = cfg.text_config + + offline_data_dir = generate_offline_pt_data( + tmp_path / "offline_data", + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + num_aux_layers=min(cfg.num_hidden_layers, 3), + ) + + tiny_eagle_config = { + "max_position_embeddings": 128, + "num_hidden_layers": 1, + 
"intermediate_size": 64, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 64, + } + config_file = tmp_path / "tiny_eagle_config_offline.json" + with open(config_file, "w") as f: + json.dump(tiny_eagle_config, f) + + cmd = [ + "./launch_train.sh", + "--model", model_path, + "--data", tiny_daring_anteater_path, + "--offline-data", offline_data_dir, + "--num_epochs", "0.1", + "--lr", "1e-5", + "--mode", "eagle3", + "--eagle_config", str(config_file), + "--output_dir", output_subdir, + "--training_seq_len", "64", + "--trust_remote_code", "True", + "--fsdp", "False", + ] + if use_fake_base: + cmd += ["--use_fake_base_for_offline", "true"] + run_example_command(cmd, "speculative_decoding") + assert os.path.exists(output_subdir / "config.json") + + +def test_offline_resume_training_kimi(tiny_daring_anteater_path, tmp_path, eagle_output_dir): + """Test resume of offline Eagle3 training from a FakeBaseModel checkpoint (Kimi-K2.5). + + Depends on test_offline_eagle3_training["kimi-k2.5"] having run first. + Exercises AutoModelForCausalLM.from_pretrained with model_type='fake_base_model'. 
+ """ + import transformers + + checkpoint_dir = eagle_output_dir / "eagle-Kimi-K2.5-offline" + config = transformers.AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True) + + offline_data_dir = generate_offline_pt_data( + tmp_path / "offline_data_resume", + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + num_aux_layers=min(config.num_hidden_layers, 3), + ) + + run_example_command( + [ + "./launch_train.sh", + "--model", checkpoint_dir, + "--data", tiny_daring_anteater_path, + "--offline-data", offline_data_dir, + "--num_epochs", "0.2", + "--lr", "1e-5", + "--mode", "eagle3", + "--output_dir", checkpoint_dir, + "--training_seq_len", "64", + "--trust_remote_code", "True", + "--fsdp", "False", + "--use_fake_base_for_offline", "true", + ], + "speculative_decoding", + ) diff --git a/tests/unit/torch/speculative/plugins/test_fakebase.py b/tests/unit/torch/speculative/plugins/test_fakebase.py new file mode 100644 index 000000000..4698dd402 --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_fakebase.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for FakeBaseModel and the fake-base / offline paths in load_vlm_or_llm.""" + +import json + +import pytest +import safetensors.torch +import torch + +pytest.importorskip("transformers") +import transformers + +from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel +from modelopt.torch.speculative.utils import load_vlm_or_llm + +_HIDDEN_SIZE = 16 +_VOCAB_SIZE = 32 + + +@pytest.fixture +def fake_config(monkeypatch): + """Monkeypatch AutoConfig.from_pretrained to return a minimal fake config.""" + cfg = transformers.PretrainedConfig() + cfg.model_type = "llama" + cfg.hidden_size = _HIDDEN_SIZE + cfg.vocab_size = _VOCAB_SIZE + cfg.num_hidden_layers = 2 + cfg.max_position_embeddings = 128 + cfg.tie_word_embeddings = False + monkeypatch.setattr(transformers.AutoConfig, "from_pretrained", lambda *a, **kw: cfg) + return cfg + + +@pytest.fixture +def fake_checkpoint(tmp_path, fake_config): + """Minimal local safetensors checkpoint loadable by FakeBaseModel.""" + tensors = { + "lm_head.weight": torch.zeros(_VOCAB_SIZE, _HIDDEN_SIZE), + "embed_tokens.weight": torch.zeros(_VOCAB_SIZE, _HIDDEN_SIZE), + } + shard = tmp_path / "model-00001-of-00001.safetensors" + safetensors.torch.save_file(tensors, shard) + index = {"weight_map": dict.fromkeys(tensors, shard.name)} + (tmp_path / "model.safetensors.index.json").write_text(json.dumps(index)) + return tmp_path + + +def test_fakebase_local_happy_path(fake_checkpoint): + model = FakeBaseModel.from_source(str(fake_checkpoint)) + assert model.lm_head.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE]) + assert model.embed_tokens.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE]) + + +def test_fakebase_missing_index_raises(tmp_path, fake_config): + with pytest.raises(FileNotFoundError, match="safetensors"): + FakeBaseModel.from_source(str(tmp_path)) + + +def test_load_vlm_or_llm_returns_fakebase(fake_checkpoint): + model = load_vlm_or_llm(str(fake_checkpoint), use_offline_training=True, 
use_fake_base=True) + assert isinstance(model, FakeBaseModel) + + +def test_load_vlm_or_llm_offline_zero_layers(monkeypatch): + cfg = transformers.PretrainedConfig() + cfg.model_type = "llama" + cfg.num_hidden_layers = 4 + monkeypatch.setattr(transformers.AutoConfig, "from_pretrained", lambda *a, **kw: cfg) + + captured_kwargs = {} + + class _FakeModel: + config = cfg + + def _fake_from_pretrained(*args, **kwargs): + captured_kwargs.update(kwargs) + return _FakeModel() + + monkeypatch.setattr(transformers.AutoModelForCausalLM, "from_pretrained", _fake_from_pretrained) + + model = load_vlm_or_llm("fake-model", use_offline_training=True, use_fake_base=False) + assert captured_kwargs.get("num_hidden_layers") == 0 + assert model.config.num_orig_hidden_layers == 4