Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activati
pruning_mixin:
_target_: modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin.FFNIntermediatePruningMixIn
layer_descriptor:
_target_: modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor.Qwen3_8BFFNIntermediateLayerDescriptor
_target_: modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor.Qwen3FFNIntermediateLayerDescriptor

hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IterativeChannelContributionHook}
activation_hooks_kwargs:
method: iterative
target_layer: "mlp.down_proj"
layer_input_descriptors_path:

intermediate_size_list: [256] # teacher_intermediate_size is 14336
intermediate_size_list: [256] # teacher_intermediate_size is 14336
mlp_init_mode: "PruneByActivationsLog"
3 changes: 1 addition & 2 deletions examples/puzzletron/mbridge_distillation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.
--teacher_hf_path /path/to/teacher/huggingface/checkpoint \
--data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \
--output_dir /path/to/distilled/checkpoint \
--hf-export-path /path/to/exported/hf/model \
--hf-model meta-llama/Llama-3.1-8B-Instruct \
--hf_export_path /path/to/exported/hf/model \
--seq_length 4096 \
--tp_size 8 \
--pp_size 1 \
Expand Down
75 changes: 16 additions & 59 deletions examples/puzzletron/mbridge_distillation/distill_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@

import argparse
import os
import traceback

import megatron.bridge.models.distillation_provider
import torch
from megatron.bridge import AutoBridge
from megatron.bridge.models.distillation_provider import convert_to_distillation_provider
from megatron.bridge.recipes.utils.optimizer_utils import (
distributed_fused_adam_with_cosine_annealing,
)
Expand All @@ -40,39 +39,16 @@
TokenizerConfig,
TrainingConfig,
)
from megatron.bridge.training.distill import distill
from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.distributed import DistributedDataParallelConfig

# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure
# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers
# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge
# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration.
#
# Note: Currently, bridges are also registered when distillation_provider is imported
# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider
# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron.
# Import to register heterogeneous bridges (side effect)
import modelopt.torch.puzzletron.export.mbridge # noqa: F401
import modelopt.torch.utils.distributed as dist

# Use local copy of distillation_provider with fix for heterogeneous models
# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge
from modelopt.torch.puzzletron.export.mbridge.distillation_provider import (
DistillationProvider,
convert_to_distillation_provider,
)
from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import (
export_to_hf_and_copy_config,
)
from modelopt.torch.utils import print_rank_0

# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider
# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time
megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider

# Import distill() AFTER patching so it uses the patched DistillationProvider
from megatron.bridge.training.distill import distill # noqa: E402

SEED = 1234


Expand Down Expand Up @@ -145,22 +121,13 @@ def get_args():
# Export arguments
parser.add_argument(
"--hf_export_path",
"--hf-export-path",
type=str,
default=None,
help=(
"Path where to save the HuggingFace export. "
"If provided, exports checkpoint to HF format after distillation."
"If provided, exports last iteration checkpoint to HF format after distillation."
),
)
parser.add_argument(
"--hf_model",
"--hf-model",
type=str,
required=True,
help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). "
"Should match the base architecture of the student model.",
)
args = parser.parse_args()

# Sanity checks
Expand Down Expand Up @@ -288,42 +255,32 @@ def _build_model_provider(hf_path):

# Export to HuggingFace format if hf_export_path is provided
if args.hf_export_path:
# Wait for all ranks to finish distillation before export
if torch.distributed.is_initialized():
torch.distributed.barrier()

# Save rank before destroying process group (dist.rank() won't work after destruction)
is_rank_0 = dist.rank() == 0

# Destroy process group on all ranks - export_ckpt will create its own temporary one
# This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone)
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
dist.cleanup()

# Only rank 0 exports
if is_rank_0:
try:
export_to_hf_and_copy_config(
student_hf_path=args.student_hf_path,
checkpoint_dir=checkpoint_dir,
train_iters=args.train_iters,
hf_export_path=args.hf_export_path,
hf_model=args.hf_model,
trust_remote_code=args.trust_remote_code,
)
except Exception as e:
print(f"⚠️ Export failed: {e}")
traceback.print_exc()
bridge = AutoBridge.from_hf_pretrained(
args.student_hf_path, trust_remote_code=args.trust_remote_code
)
# Create subblocks_safetensors directory else safetensors saving will fail
os.makedirs(os.path.join(args.hf_export_path, "subblocks_safetensors"), exist_ok=True)
bridge.export_ckpt(
megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}",
hf_path=args.hf_export_path,
show_progress=True,
strict=True,
)


if __name__ == "__main__":
dist.setup()
args = get_args()
try:
main(args)
except Exception as e:
print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}")
print_rank_0(f"Traceback:\n{traceback.format_exc()}")
raise
finally:
dist.cleanup()
4 changes: 2 additions & 2 deletions modelopt/torch/puzzletron/anymodel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@
from modelopt.torch.puzzletron.anymodel.models.nemotron_h import *
from modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2 import *
from modelopt.torch.puzzletron.anymodel.models.qwen2 import *
from modelopt.torch.puzzletron.anymodel.models.qwen3_8b import *
from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct import *
from modelopt.torch.puzzletron.anymodel.models.qwen3 import *
from modelopt.torch.puzzletron.anymodel.models.qwen3_vl import *
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def init_rotary_embedding(model, runtime):
"""
NemotronH has no positional embeddings
"""
pass

@staticmethod
def input_embedding_name():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def init_rotary_embedding(model, runtime):
"""
NemotronH has no positional embeddings
"""
pass

@staticmethod
def input_embedding_name():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,3 @@ class Qwen2FFNIntermediateLayerDescriptor(LlamaFFNIntermediateLayerDescriptor):

Qwen2 uses the same FFN structure as Llama (gate_proj, up_proj, down_proj).
"""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_converter import (
Qwen3VL30BA3BInstructConverter,
)
from modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor import (
Qwen3VL30BA3BInstructModelDescriptor,
from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_converter import Qwen3Converter
from modelopt.torch.puzzletron.anymodel.models.qwen3.qwen3_model_descriptor import (
Qwen3ModelDescriptor,
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


@ConverterFactory.register_decorator("qwen3")
class Qwen3_8BConverter(Converter):
class Qwen3Converter(Converter):
@staticmethod
def create_block_configs_from_main_config(config: Qwen3Config) -> List[BlockConfig]:
num_hidden_layers = config.num_hidden_layers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


@ModelDescriptorFactory.register_decorator("qwen3")
class Qwen3_8BModelDescriptor(ModelDescriptor):
class Qwen3ModelDescriptor(ModelDescriptor):
@staticmethod
def decoder_layer_cls():
return Qwen3DecoderLayer
Expand Down Expand Up @@ -135,7 +135,7 @@ def build_attention_predicates() -> Dict[str, re.Pattern]:


@dataclass
class Qwen3_8BFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
class Qwen3FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
down_proj_name: str = "mlp.down_proj"
ffn_prefix_name: str = "model.layers.{layer_idx}.mlp"
linear_weight_names: List[str] = field(
Expand All @@ -144,7 +144,7 @@ class Qwen3_8BFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):


@dataclass
class Qwen3_8BKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
class Qwen3KVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "self_attn.o_proj"
attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
qkvo_weight_names: List[str] = field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_converter import Qwen3_8BConverter
from modelopt.torch.puzzletron.anymodel.models.qwen3_8b.qwen3_8b_model_descriptor import (
Qwen3_8BModelDescriptor,
from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_converter import Qwen3VLConverter
from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor import (
Qwen3VLModelDescriptor,
)
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


@ConverterFactory.register_decorator("qwen3_vl")
class Qwen3VL30BA3BInstructConverter(Converter):
class Qwen3VLConverter(Converter):
@staticmethod
def create_block_configs_from_main_config(config: Qwen3VLMoeConfig) -> List[BlockConfig]:
# Qwen3-VL MoE has nested text_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from dataclasses import dataclass, field
from typing import Dict, List

import torch.nn as nn
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextDecoderLayer,
Qwen3VLMoeTextRotaryEmbedding,
Expand All @@ -46,7 +45,7 @@


@ModelDescriptorFactory.register_decorator("qwen3_vl")
class Qwen3VL30BA3BInstructModelDescriptor(ModelDescriptor):
class Qwen3VLModelDescriptor(ModelDescriptor):
@staticmethod
def uses_autocast() -> bool:
"""
Expand Down Expand Up @@ -90,7 +89,7 @@ def mlp_no_op_post_init(decoder_layer: Qwen3VLMoeTextDecoderLayer):
@staticmethod
def init_rotary_embedding(model, runtime):
# Re-initialize text rotary embedding on correct device and dtype
text_config = Qwen3VL30BA3BInstructModelDescriptor.get_language_model_config(model.config)
text_config = Qwen3VLModelDescriptor.get_language_model_config(model.config)
model.model.language_model.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(
config=text_config
).to(device=runtime.device, dtype=runtime.dtype)
Expand Down Expand Up @@ -171,7 +170,7 @@ def build_attention_predicates() -> Dict[str, re.Pattern]:


@dataclass
class Qwen3VL30BA3BInstructFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
class Qwen3VLFFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
down_proj_name: str = "mlp.down_proj"
ffn_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp"
linear_weight_names: List[str] = field(
Expand All @@ -180,7 +179,7 @@ class Qwen3VL30BA3BInstructFFNIntermediateLayerDescriptor(FFNIntermediateLayerDe


@dataclass
class Qwen3VL30BA3BInstructKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
class Qwen3VLKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "self_attn.o_proj"
attn_prefix_name: str = "model.language_model.layers.{layer_idx}.self_attn"
qkvo_weight_names: List[str] = field(
Expand All @@ -189,7 +188,7 @@ class Qwen3VL30BA3BInstructKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):


@dataclass
class Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
class Qwen3VLExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
"""
Qwen3-VL MoE layer descriptor.

Expand All @@ -203,7 +202,7 @@ class Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor(ExpertRemovalLayerDescri
moe_prefix_name: str = "model.language_model.layers.{layer_idx}.mlp"
# Router: Qwen3VLMoeTextTopKRouter has self.weight, no bias
router_weights: List[str] = field(default_factory=lambda: ["gate.weight"])
router_biases: List[str] = field(default_factory=lambda: [])
router_biases: List[str] = field(default_factory=list)
# Fused expert format: Qwen3VLMoeTextExperts stores all experts in single tensors
# with shape [num_experts, ...] instead of separate tensors per expert.
is_fused_experts: bool = True
Expand Down
Loading
Loading