diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index bdfc4ee383..ace19d0c31 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -122,9 +122,16 @@ set -x SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" NUM_NODES=${NUM_NODES:-1} -GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} -TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) -echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +if [[ "$NUM_NODES" != 1 ]]; then + #Multi Node Training + GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} + TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) + echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +else + #Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES + TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (Single Node Training)" +fi # Calculate save_steps DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) @@ -180,15 +187,17 @@ else VLM_ARGS="" fi +FSDP_ARGS="" if [[ "$TOTAL_GPU" -gt 1 ]]; then - #Use FSDP2 when multi GPU available - FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" -else - #Otherwise, single GPU training - FSDP_ARGS="" + # Use FSDP when multi GPU available, default to FSDP1 + FSDP_ARGS="$FSDP_ARGS --fsdp 'full_shard'" + TRANSFORMERS_5=$(python -c "from packaging.version import Version; import transformers; print(Version(transformers.__version__) >= Version('5.0'))" 2>/dev/null) + if [[ "$TRANSFORMERS_5" == "True" ]]; then + # For transformers >= 5.0, use FSDP2 + FSDP_ARGS="$FSDP_ARGS --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" + fi fi - if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE" else diff --git a/examples/speculative_decoding/main.py 
b/examples/speculative_decoding/main.py index 6821111849..ca0cf4c30c 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -44,6 +44,7 @@ patch_ring_attention_for_ttt, ) from medusa_utils import make_medusa_supervised_data_module +from packaging.version import Version from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto @@ -54,6 +55,8 @@ ) from modelopt.torch.utils import print_rank_0 +TRANSFORMERS_VERSION = Version(transformers.__version__) + torch.manual_seed(0) mto.enable_huggingface_checkpointing() @@ -142,9 +145,10 @@ def train(): model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) - training_args.parallelism_config = ParallelismConfig( - cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size - ) + if Version("5.0") <= TRANSFORMERS_VERSION: + training_args.parallelism_config = ParallelismConfig( + cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size + ) if training_args.cp_size > 1: patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 diff --git a/modelopt/torch/speculative/plugins/hf_model_patches.py b/modelopt/torch/speculative/plugins/hf_model_patches.py new file mode 100644 index 0000000000..72e7a5d67e --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_model_patches.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A registry of target model-specific patches on class HFEagleModel. + +The patches are indexed by model type and will be called by HFEagleModel object at the end of modify(). +""" + +import types +from collections.abc import Callable + +import transformers +from packaging.version import Version +from transformers.utils.quantization_config import CompressedTensorsConfig + +__all__ = ["apply_model_patch"] + +_MODEL_PATCH_REGISTRY: dict[str, Callable] = {} + + +def register_model_patch(model_type: str): + """Decorator to register a patch function for a specific model type.""" + + def decorator(func: Callable): + _MODEL_PATCH_REGISTRY[model_type] = func + return func + + return decorator + + +def apply_model_patch(module): + """Apply a registered patch to the given module based on model_type.""" + model_type = getattr(module.config, "model_type", None) + if model_type in _MODEL_PATCH_REGISTRY: + _MODEL_PATCH_REGISTRY[model_type](module) + + +@register_model_patch("kimi_k2") +def patch_for_kimi_k2(module): + """Patch for Kimi-K2-Thinking as target model. + + - Version check for transformers < 5.0 + - Avoid quantizing drafter by updating quantization_config + - Repeat attention mask at batch dimension + """ + if Version(transformers.__version__) >= Version("5.0"): + raise RuntimeError( + "Kimi K2 is not supported for transformers >= 5.0. 
\ + Please install transformers >=4.57, <5.0" + ) + + if module.eagle_config._attn_implementation == "flex_attention": + raise ValueError("Kimi K2 does not support flex attention.") + + # Avoid quantizing drafter by updating quantization_config + quant_config = getattr(module.config, "quantization_config", None) + if isinstance(quant_config, CompressedTensorsConfig): + quant_config.quantization_config.ignore.append("re:.*eagle_module.*") + + # Kimi K2 assert attention mask shape as (bsz, 1, qlen, kvlen) + # https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/modeling_deepseek.py#L829 + # So we repeat the attention mask at batch dimension + original_func = module._compute_ttt_attention_mask + + def _patched_compute_ttt_attention_mask(self, batch_size, seq_length, ttt_step): + tensor_mask = original_func(batch_size, seq_length, ttt_step) + return tensor_mask.repeat(batch_size, 1, 1, 1) + + module._compute_ttt_attention_mask = types.MethodType( + _patched_compute_ttt_attention_mask, module + ) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 5e7ff9c8e7..2856dab028 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -48,7 +48,6 @@ ) from transformers.trainer_pt_utils import LabelSmoother from transformers.utils import ModelOutput -from transformers.utils.quantization_config import CompressedTensorsConfig from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel @@ -63,6 +62,7 @@ get_ttt_msk_func, temporary_set_config_value, ) +from .hf_model_patches import apply_model_patch __all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"] @@ -584,11 +584,6 @@ def modify( if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" - # Patch for Kimi-K2-Thinking, avoid quantizing drafter - quant_config = getattr(self.config, "quantization_config", None) 
- if isinstance(quant_config, CompressedTensorsConfig): - quant_config.ignore.append("re:.*eagle_module.*") - # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state @@ -628,6 +623,9 @@ def modify( self.num_ttt_steps = 4 # NOTE: (hg) hardcoded for now. Might add to config later. self._cached_attn_blk_masks = {} + # Apply model-specific patch if needed + apply_model_patch(self) + def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): # compile and cached flex attention masks in first call if ttt_step not in self._cached_attn_blk_masks: @@ -750,8 +748,6 @@ def _compute_ttt_attention_mask( tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device ).masked_fill(~tensor_mask, dtypemin) - # Note: (hg) repeat mask for kimi-k2 compatibility - tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask def _base_model_forward( diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 72c5b5dbc0..856704cffe 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -422,11 +422,9 @@ def _setup_kimi_k2_decoder(): # Import required modules config_module_path = os.path.join(kimi_k2_path, "configuration_deepseek.py") model_module_path = os.path.join(kimi_k2_path, "modeling_deepseek.py") - _import_module_from_path( config_module_path, f"{KIMI_K2_PACKAGE_NAME}.configuration_deepseek", KIMI_K2_PACKAGE_NAME ) - kimi_k2_module = _import_module_from_path( model_module_path, f"{KIMI_K2_PACKAGE_NAME}.modeling_deepseek", KIMI_K2_PACKAGE_NAME ) @@ -443,6 +441,16 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): kimi_k2_module.DeepseekV3Attention._init_rope = lambda self: None kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init + # Kimi implementation is based on older transformers which use "past_key_value" argument + # We patch it to "past_key_values" for compatibility + original_decoder_layer_forward = 
kimi_k2_module.DeepseekV3DecoderLayer.forward + + def patched_decoder_layer_fwd(self, *args, **kwargs): + kwargs["past_key_value"] = kwargs.pop("past_key_values", None) + return original_decoder_layer_forward(self, *args, **kwargs) + + kimi_k2_module.DeepseekV3DecoderLayer.forward = patched_decoder_layer_fwd + + return getattr(kimi_k2_module, "DeepseekV3DecoderLayer")