Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,16 @@ set -x

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
NUM_NODES=${NUM_NODES:-1}
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
if [[ "$NUM_NODES" != 1 ]]; then
#Multi Node Training
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
else
#Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES
TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())")
echo "Total GPUs: $TOTAL_GPU (Single Node Training)"
fi
# Calculate save_steps
DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU))

Expand Down Expand Up @@ -180,15 +187,17 @@ else
VLM_ARGS=""
fi

FSDP_ARGS=""
if [[ "$TOTAL_GPU" -gt 1 ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
# Use FSDP when multi GPU available, default to FSDP1
FSDP_ARGS="$FSDP_ARGS --fsdp 'full_shard'"
TRANSFORMERS_5=$(python -c "from packaging.version import Version; import transformers; print(Version(transformers.__version__) >= Version('5.0'))" 2>/dev/null)
if [[ "$TRANSFORMERS_5" == "True" ]]; then
# For transformers >= 5.0, use FSDP2
FSDP_ARGS="$FSDP_ARGS --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
fi
fi


if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then
DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE"
else
Expand Down
10 changes: 7 additions & 3 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
patch_ring_attention_for_ttt,
)
from medusa_utils import make_medusa_supervised_data_module
from packaging.version import Version
from transformers.trainer_utils import get_last_checkpoint

import modelopt.torch.opt as mto
Expand All @@ -54,6 +55,8 @@
)
from modelopt.torch.utils import print_rank_0

TRANSFORMERS_VERSION = Version(transformers.__version__)

torch.manual_seed(0)
mto.enable_huggingface_checkpointing()

Expand Down Expand Up @@ -142,9 +145,10 @@ def train():
model_args, data_args, training_args, medusa_args, eagle_args = (
parser.parse_args_into_dataclasses()
)
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if Version("5.0") <= TRANSFORMERS_VERSION:
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1:
patch_ring_attention_for_ttt()
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
Expand Down
83 changes: 83 additions & 0 deletions modelopt/torch/speculative/plugins/hf_model_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A registry of target model-specific patches on class HFEagleModel.

The patches are indexed by model type and are called by the HFEagleModel object at the end of modify().
"""

import types
from collections.abc import Callable

import transformers
from packaging.version import Version
from transformers.utils.quantization_config import CompressedTensorsConfig

# Public API of this module. NOTE: this was previously `all = [...]`, which is a
# plain variable with no special meaning; `__all__` is what `from module import *`
# and documentation tools actually consult.
__all__ = ["apply_model_patch"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Typo: all should be __all__.

The module-level all variable has no special meaning in Python. It should be __all__ (with double underscores) to properly control what gets exported when using from module import *.

Fix
-all = ["apply_model_patch"]
+__all__ = ["apply_model_patch"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_model_patches.py` at line 28, Module
defines a module-level export list as all = ["apply_model_patch"] which is a
typo; rename it to __all__ = ["apply_model_patch"] so Python's import mechanism
recognizes the intended exported symbol (apply_model_patch). Update the variable
name in the module where apply_model_patch is defined to use double-underscore
__all__ to control from module import * exports.


# Maps a HF `config.model_type` string to the patch callable to apply to that model.
_MODEL_PATCH_REGISTRY: dict[str, Callable] = {}


def register_model_patch(model_type: str):
    """Decorator factory: register the decorated callable as the patch for *model_type*.

    The decorated function is stored in ``_MODEL_PATCH_REGISTRY`` keyed by
    ``model_type`` and returned unchanged, so stacking with other decorators works.
    """

    def _bind(patch_fn: Callable) -> Callable:
        _MODEL_PATCH_REGISTRY[model_type] = patch_fn
        return patch_fn

    return _bind


def apply_model_patch(module):
    """Run the registered patch for ``module.config.model_type``, if one exists.

    Silently does nothing when the model type is absent from the registry
    (or when the config has no ``model_type`` attribute at all).
    """
    patch_fn = _MODEL_PATCH_REGISTRY.get(getattr(module.config, "model_type", None))
    if patch_fn is not None:
        patch_fn(module)


@register_model_patch("kimi_k2")
def patch_for_kimi_k2(module):
    """Patch for Kimi-K2-Thinking as target model.

    - Version check for transformers < 5.0
    - Avoid quantizing drafter by updating quantization_config
    - Repeat attention mask at batch dimension
    """
    # Kimi K2 ships a custom modeling file written against the pre-5.0
    # transformers API, so fail fast on unsupported versions.
    if Version(transformers.__version__) >= Version("5.0"):
        # NOTE(review): the backslash continuation embeds the second line's
        # leading whitespace into the error message — confirm this is intended.
        raise RuntimeError(
            "Kimi K2 is not supported for transformers >= 5.0. \
            Please install transformers >=4.57, <5.0"
        )

    if module.eagle_config._attn_implementation == "flex_attention":
        raise ValueError("Kimi K2 does not support flex attention.")

    # Avoid quantizing drafter by updating quantization_config
    quant_config = getattr(module.config, "quantization_config", None)
    if isinstance(quant_config, CompressedTensorsConfig):
        # assumes CompressedTensorsConfig nests the editable config under
        # `.quantization_config` (older code appended to `.ignore` directly) —
        # TODO confirm against the installed transformers version.
        quant_config.quantization_config.ignore.append("re:.*eagle_module.*")

    # Kimi K2 assert attention mask shape as (bsz, 1, qlen, kvlen)
    # https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/modeling_deepseek.py#L829
    # So we repeat the attention mask at batch dimension
    # `original_func` is captured as a *bound* method, so the wrapper below
    # calls it without passing `self`.
    original_func = module._compute_ttt_attention_mask

    def _patched_compute_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
        tensor_mask = original_func(batch_size, seq_length, ttt_step)
        return tensor_mask.repeat(batch_size, 1, 1, 1)

    # Rebind the patched function as an instance attribute so only this module
    # instance is affected (applying this patch twice would repeat the mask twice).
    module._compute_ttt_attention_mask = types.MethodType(
        _patched_compute_ttt_attention_mask, module
    )
12 changes: 4 additions & 8 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
)
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import ModelOutput
from transformers.utils.quantization_config import CompressedTensorsConfig

from ..eagle.conversion import EagleDMRegistry
from ..eagle.eagle_model import EagleModel
Expand All @@ -63,6 +62,7 @@
get_ttt_msk_func,
temporary_set_config_value,
)
from .hf_model_patches import apply_model_patch

__all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"]

Expand Down Expand Up @@ -584,11 +584,6 @@ def modify(
if self.eagle_config._attn_implementation is None:
self.eagle_config._attn_implementation = "sdpa"

# Patch for Kimi-K2-Thinking, avoid quantizing drafter
quant_config = getattr(self.config, "quantization_config", None)
if isinstance(quant_config, CompressedTensorsConfig):
quant_config.ignore.append("re:.*eagle_module.*")

# Set default aux_hidden_state layers
if (
self.eagle_config.use_aux_hidden_state
Expand Down Expand Up @@ -628,6 +623,9 @@ def modify(
self.num_ttt_steps = 4 # NOTE: (hg) hardcoded for now. Might add to config later.
self._cached_attn_blk_masks = {}

# Apply model-specific patch if needed
apply_model_patch(self)

def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
# compile and cached flex attention masks in first call
if ttt_step not in self._cached_attn_blk_masks:
Expand Down Expand Up @@ -750,8 +748,6 @@ def _compute_ttt_attention_mask(
tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device
).masked_fill(~tensor_mask, dtypemin)

# Note: (hg) repeat mask for kimi-k2 compatibility
tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1)
return tensor_mask

def _base_model_forward(
Expand Down
12 changes: 10 additions & 2 deletions modelopt/torch/speculative/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,9 @@ def _setup_kimi_k2_decoder():
# Import required modules
config_module_path = os.path.join(kimi_k2_path, "configuration_deepseek.py")
model_module_path = os.path.join(kimi_k2_path, "modeling_deepseek.py")

_import_module_from_path(
config_module_path, f"{KIMI_K2_PACKAGE_NAME}.configuration_deepseek", KIMI_K2_PACKAGE_NAME
)

kimi_k2_module = _import_module_from_path(
model_module_path, f"{KIMI_K2_PACKAGE_NAME}.modeling_deepseek", KIMI_K2_PACKAGE_NAME
)
Expand All @@ -443,6 +441,16 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs):
kimi_k2_module.DeepseekV3Attention._init_rope = lambda self: None
kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init

# Kimi implementation is based on older transformers which use "past_key_value" argument
# We patch it to "past_key_values" for compatibility
original_decoder_layer_forward = kimi_k2_module.DeepseekV3DecoderLayer.forward

def patched_decoder_layer_fwd(self, *args, **kwargs):
kwargs["past_key_value"] = kwargs.get("past_key_values")
return original_decoder_layer_forward(self, *args, **kwargs)

kimi_k2_module.DeepseekV3DecoderLayer.forward = patched_decoder_layer_fwd

return getattr(kimi_k2_module, "DeepseekV3DecoderLayer")


Expand Down
Loading