Support qwen3-next MTP Training #1575

Open

zx3xyy wants to merge 2 commits into THUDM:main from zx3xyy:pr/qwen3-next-mtp-core

Conversation

@zx3xyy commented Feb 11, 2026

Bug: linear_qgkv not handled in weight converter

Problem

Upstream's scripts/models/qwen3-next-80B-A3B.sh enables --use-gated-attention, which causes Megatron to create parameters named linear_qgkv.weight instead of linear_qkv.weight. But the weight converters (megatron_to_hf/qwen3_next.py and mbridge/qwen3_next.py) only handle linear_qkv, causing a crash during weight sync.

Root Cause

In Megatron's megatron/core/transformer/attention.py (line 891):

if self.config.use_gated_attention:
    self.linear_qgkv = build_module(...)   # Q+G+K+V fused
else:
    self.linear_qkv = build_module(...)    # Q+K+V fused

MTP layers inherit this config, so they also use linear_qgkv.
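
For reference, a minimal sketch of the fused layout (illustrative only, not the converter code added by this PR): Megatron stores linear_qgkv.weight per query group as [Q_g, G_g, K_g, V_g], so splitting it only requires the head counts. The numbers follow the Qwen3-Next-80B-A3B config used in the reproduction below; the final HF-side interleaving of Q with its gate is omitted here.

import torch

# Sketch only (assumed helper, not the PR's converter): split a fused
# linear_qgkv.weight whose rows are laid out per query group as [Q_g, G_g, K_g, V_g].
def split_qgkv_blocks(param, num_heads=16, num_groups=2, head_dim=256):
    heads_per_group = num_heads // num_groups
    q_rows = heads_per_group * head_dim              # rows of Q (and of G) per group
    rows_per_group = 2 * q_rows + 2 * head_dim       # Q + G + K + V per group
    q, g, k, v = [], [], [], []
    for block in param.split(rows_per_group, dim=0):
        q.append(block[:q_rows])
        g.append(block[q_rows:2 * q_rows])
        k.append(block[2 * q_rows:2 * q_rows + head_dim])
        v.append(block[2 * q_rows + head_dim:])
    return tuple(torch.cat(parts, dim=0) for parts in (q, g, k, v))

# (2 * 16 + 2 + 2) * 256 = 9216 rows, the shape used in the reproduction below.
q, g, k, v = split_qgkv_blocks(torch.randn(9216, 2048))
print(q.shape, g.shape, k.shape, v.shape)  # (4096, 2048), (4096, 2048), (512, 2048), (512, 2048)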

Reproduction

Test script that loads the upstream converter and calls it with linear_qgkv parameter names (as produced by Megatron when --use-gated-attention is set):

import re, torch, traceback
from types import SimpleNamespace

# Load the upstream converter function directly
with open('slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py') as f:
    source = f.read()
exec(compile(source, 'qwen3_next.py', 'exec'))

# Qwen3-Next-80B-A3B model config
args = SimpleNamespace(
    num_attention_heads=16,
    num_query_groups=2,
    kv_channels=256,
    hidden_size=2048,
)
param = torch.randn(9216, 2048)

# Test 1: linear_qkv.weight (handled by upstream - works)
try:
    result = convert_qwen3_next_to_hf(args,
        'module.module.decoder.layers.0.self_attention.linear_qkv.weight', param)
    print(f'[OK] linear_qkv.weight -> {[r[0] for r in result]}')
except Exception as e:
    print(f'[FAIL] linear_qkv.weight: {e}')

# Test 2: linear_qgkv.weight (created when --use-gated-attention is set)
try:
    result = convert_qwen3_next_to_hf(args,
        'module.module.decoder.layers.0.self_attention.linear_qgkv.weight', param)
    print(f'[OK] linear_qgkv.weight -> {[r[0] for r in result]}')
except ValueError as e:
    traceback.print_exc()

# Test 3: linear_qgkv.layer_norm_weight
try:
    result = convert_qwen3_next_to_hf(args,
        'module.module.decoder.layers.0.self_attention.linear_qgkv.layer_norm_weight',
        torch.randn(2048))
    print(f'[OK] linear_qgkv.layer_norm_weight -> {[r[0] for r in result]}')
except ValueError as e:
    traceback.print_exc()

# Test 4: MTP layer with linear_qgkv (MTP inherits use_gated_attention from config)
try:
    result = convert_qwen3_next_to_hf(args,
        'module.module.mtp.layers.0.transformer_layer.self_attention.linear_qgkv.weight',
        param)
    print(f'[OK] MTP linear_qgkv.weight -> {[r[0] for r in result]}')
except ValueError as e:
    traceback.print_exc()

Output (before fix)

[OK] linear_qkv.weight -> ['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight']

Traceback (most recent call last):
  File "qwen3_next.py", line 194, in convert_qwen3_next_to_hf
ValueError: Unknown parameter name: module.module.decoder.layers.0.self_attention.linear_qgkv.weight

Traceback (most recent call last):
  File "qwen3_next.py", line 194, in convert_qwen3_next_to_hf
ValueError: Unknown parameter name: module.module.decoder.layers.0.self_attention.linear_qgkv.layer_norm_weight

Traceback (most recent call last):
  File "qwen3_next.py", line 194, in convert_qwen3_next_to_hf
ValueError: Unknown parameter name: module.module.decoder.layers.0.self_attention.linear_qgkv.weight

Summary

[OK]            linear_qkv.weight             -> converts correctly
[BUG CONFIRMED] linear_qgkv.weight            -> ValueError: Unknown parameter name
[BUG CONFIRMED] linear_qgkv.layer_norm_weight -> ValueError: Unknown parameter name
[BUG CONFIRMED] MTP linear_qgkv.weight        -> ValueError: Unknown parameter name

After fix

[OK] linear_qkv.weight             -> ['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight']
[OK] linear_qgkv.weight            -> ['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight']
[OK] linear_qgkv.layer_norm_weight -> ['model.layers.0.input_layernorm.weight']
[OK] MTP linear_qgkv.weight        -> ['mtp.layers.0.self_attn.q_proj.weight', 'mtp.layers.0.self_attn.k_proj.weight', 'mtp.layers.0.self_attn.v_proj.weight']

Bug: Incorrect eh_proj.weight half-swap in MTP conversion

Problem

The _convert_mtp_layer function in megatron_to_hf/qwen3_next.py and _weight_to_mcore_format / _weight_to_hf_format in mbridge/qwen3_next.py swap the two halves of eh_proj.weight along dim=1 when converting between HF and Megatron formats:

first_half, second_half = param.chunk(2, dim=1)
new_param = torch.cat([second_half, first_half], dim=1)

This swap is incorrect — both HF and Megatron use the same input order [embedding, hidden_state].

Evidence

HF (SGLang) forward (sglang/python/sglang/srt/models/qwen3_next_mtp.py):

hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))

Megatron forward (megatron/core/transformer/multi_token_prediction.py):

hidden_states = torch.cat((decoder_input, hidden_states), -1)
hidden_states, _ = self.eh_proj(hidden_states)

Both concatenate [embedding, hidden_state] — same order. No swap needed.
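
To make the equivalence concrete, here is a small self-contained check (illustrative code, not from either repository): swapping the column halves of eh_proj.weight is only correct if the two frameworks concatenated the inputs in opposite order, which they do not.

import torch

H = 4                                     # toy hidden size
emb, hid = torch.randn(H), torch.randn(H)
W = torch.randn(H, 2 * H)                 # stands in for eh_proj.weight

out = W @ torch.cat([emb, hid])                            # what both HF and Megatron compute
first_half, second_half = W.chunk(2, dim=1)
W_swapped = torch.cat([second_half, first_half], dim=1)    # the conversion's half-swap
out_swapped = W_swapped @ torch.cat([emb, hid])

# The swapped weight is equivalent to feeding [hidden_state, embedding] instead,
# so it changes the result unless the inputs really were concatenated in that order.
print(torch.allclose(out_swapped, W @ torch.cat([hid, emb])))  # True
print(torch.allclose(out_swapped, out))                        # False (in general)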

Validation

  • v15 checkpoint (with swap): MTP loss abnormally high, model not learning
  • v16 checkpoint (without swap): MTP loss normal, matches v13 behavior

Fix

Remove the half-swap from both megatron_to_hf/qwen3_next.py and mbridge/qwen3_next.py. Pass eh_proj.weight through directly without modification.

Results

With these two fixes, MTP training for Qwen3-Next-80B-A3B runs with a reasonable MTP loss.
(Two screenshots attached, captured 2026-02-12 at 2:16 PM.)

…ayers

Qwen3-Next's full attention and MTP layers use gated attention (linear_qgkv)
with Q+G interleaved, unlike the linear attention layers that use linear_qkv.
This adds the missing QGKV per-group layout conversion for weight sync.

megatron_to_hf/qwen3_next.py:
- Add _convert_qgkv_weight_to_hf() for Megatron per-group layout
  [Q_g0, G_g0, K_g0, V_g0, ...] to HF [Q+G interleaved, K, V]
- Add linear_qgkv.weight and linear_qgkv.layer_norm_weight handlers

mbridge/qwen3_next.py:
- Add linear_qgkv entries to _ATTENTION_MAPPING
- Add _weight_to_mcore_format for gated QGKV (HF to Megatron direction)
- Fix _get_gptmodel_args to create MTP config with use_gated_attention=True

model_provider.py:
- Create separate mtp_config with use_gated_attention=True for MTP block spec
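
The model_provider.py bullet above boils down to giving the MTP block spec its own config copy. A rough sketch of the idea, with a SimpleNamespace standing in for Megatron's TransformerConfig (structure assumed from the commit message, not copied from the diff):

import copy
from types import SimpleNamespace

# Stand-in for the real TransformerConfig (assumption for illustration).
config = SimpleNamespace(hidden_size=2048, use_gated_attention=True)

# MTP layers must build linear_qgkv like the full-attention decoder layers,
# so the MTP block spec is created from a config copy with gating forced on.
mtp_config = copy.deepcopy(config)
mtp_config.use_gated_attention = True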
zx3xyy marked this pull request as ready for review on February 12, 2026, 22:13
zx3xyy changed the title from Pr/qwen3 next mtp core to Support qwen3-next MTP Training on Feb 12, 2026
@guapisolo (Contributor) commented:

Good job! I will help review this.

v = all_v.reshape(num_kv_heads * head_dim, hidden_size).contiguous()

return [
    (f"{prefix}.self_attn.q_proj.weight", qg),
Contributor review comment on this diff:

Good job!


weight = super()._weight_to_mcore_format(mcore_weights_name, hf_weights)
if mcore_weights_name.endswith("eh_proj.weight"):
    first_half, second_half = weight.chunk(2, dim=1)
Contributor review comment on this diff:

Good fix.

Guard against division by zero in MTP loss computation when
num_tokens is 0, which can happen with context parallelism
when one CP rank has no response tokens after label rolling.
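
A minimal sketch of the kind of guard the commit describes (illustrative names, not the PR's exact code): clamp the token count so a CP rank with zero response tokens contributes a zero loss instead of NaN.

import torch

def masked_mtp_loss(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # loss_mask is 1 for response tokens, 0 elsewhere; after label rolling a CP
    # rank can end up with an all-zero mask, so guard the denominator.
    num_tokens = loss_mask.sum()
    total = (per_token_loss * loss_mask).sum()
    return total / torch.clamp(num_tokens, min=1)

# Zero response tokens on this rank -> loss 0.0 instead of NaN.
print(masked_mtp_loss(torch.randn(8), torch.zeros(8)))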
zx3xyy force-pushed the pr/qwen3-next-mtp-core branch from 7f120ed to a4f2fd8 on February 13, 2026, 07:04