Conversation

@vasqu
Contributor

@vasqu vasqu commented Jan 7, 2026

As per title, the current fp8 experts implementation is wrong - likely related to #42456 (cc @3outeille)

Before:

  • The number of experts was treated as the number of top-k experts that are hit, i.e. num_experts = top_k_weights.shape[1]
    • This is simply wrong: the routing indices can cover all experts, yet num_experts is limited to top_k
    • This can trigger device-side assert errors, since one_hot may receive indices that exceed the number of classes implied by num_experts (a minimal repro is sketched right after this list)
  • The hit expert index is then used, which is only correct if we made it past the selection above
    • and only because the indices happened to be correctly limited
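
A minimal repro of that device assert (my own sketch, not code from the PR), with made-up shapes: 8 experts, top-2 routing over 2 tokens. On CPU the mismatch shows up as a plain RuntimeError; on CUDA it surfaces as a device-side assert:

    import torch

    top_k_index = torch.tensor([[0, 7], [2, 5]])  # routed expert ids can cover all 8 experts
    top_k_weights = torch.rand(2, 2)              # (num_tokens, top_k)

    num_experts = top_k_weights.shape[1]          # old code: 2 instead of 8
    try:
        torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
    except RuntimeError as err:
        print("one_hot rejects index 7 >= num_classes 3:", err)

    # fixed: use the real expert count
    expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=8)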

After:

  • We now follow the current eager implementation, e.g. Mixtral:

        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  • The portion where we iterate through experts should only have its linear calls changed relative to the reference below (a standalone sanity check of this loop is sketched right after this list):

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
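
As a standalone sanity check (my own sketch, not code from the PR), the masked loop above reproduces a brute-force per-token routing reference with random weights; every name and shape below is made up for illustration:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    T, H, I, E, K = 6, 16, 32, 8, 2          # tokens, hidden, intermediate, experts, top-k
    act_fn = nn.SiLU()

    hidden_states = torch.randn(T, H)
    gate_up_proj = torch.randn(E, 2 * I, H)  # per-expert fused gate/up weights
    down_proj = torch.randn(E, H, I)         # per-expert down projection

    router_probs = torch.softmax(torch.randn(T, E), dim=-1)
    top_k_weights, top_k_index = torch.topk(router_probs, K, dim=-1)

    # Masked expert loop, same structure as the eager reference above.
    final_hidden_states = torch.zeros_like(hidden_states)
    with torch.no_grad():
        expert_mask = nn.functional.one_hot(top_k_index, num_classes=E).permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    for expert_idx in expert_hit:
        expert_idx = expert_idx[0]
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
        current_state = hidden_states[token_idx]
        gate, up = nn.functional.linear(current_state, gate_up_proj[expert_idx]).chunk(2, dim=-1)
        current = nn.functional.linear(act_fn(gate) * up, down_proj[expert_idx])
        current = current * top_k_weights[token_idx, top_k_pos, None]
        final_hidden_states.index_add_(0, token_idx, current)

    # Brute-force reference: loop over every token and its selected experts directly.
    reference = torch.zeros_like(hidden_states)
    for t in range(T):
        for k in range(K):
            e = top_k_index[t, k]
            gate, up = nn.functional.linear(hidden_states[t], gate_up_proj[e]).chunk(2, dim=-1)
            reference[t] += top_k_weights[t, k] * nn.functional.linear(act_fn(gate) * up, down_proj[e])

    assert torch.allclose(final_hidden_states, reference, atol=1e-4)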

I'm not sure how long it has been this way, but it is definitely not correct atm. I'd be curious whether I'm overlooking something here, but without this fix #42028 will not work for the original model - my suspicion is that Mixtral's routing is biased towards low expert indices, which may have masked this oversight so far.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu vasqu mentioned this pull request Jan 7, 2026
Collaborator

@ArthurZucker ArthurZucker left a comment


Yes, absolutely! I did not check, but if a test is missing let's add one!

@ArthurZucker
Collaborator

@3outeille note that this breaks TP, but it is already broken as it stands anyway

Member

@SunMarc SunMarc left a comment


LGTM

Comment on lines 535 to 550
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states.index_select(0, token_idx)
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
             gate, up = self.linear(
                 current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scale_inv[expert_idx]
Member


Thanks! Maybe we can add a small comment saying that this was mostly copied from the deepspeed_v3 modeling code, so that we remember to also propagate changes here in the future.

Contributor Author


Added a comment, referencing mixtral

@github-actions
Contributor

github-actions bot commented Jan 8, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: finegrained_fp8

Comment on lines +379 to +380
@unittest.skip(reason="Dependent on #42028, will be removed alongside that PR")
def test_quantized_moe_forward(self):
Contributor Author


This acts as a sanity integration check, but it depends on the MiniMax M2 PR (#42028), so I will remove the skip when merging that PR.

I think this is the easiest way, as those weights force the issue to surface.

@github-actions
Contributor

github-actions bot commented Jan 8, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43154&sha=d0d860

@vasqu vasqu merged commit 9255982 into huggingface:main Jan 8, 2026
25 checks passed
@vasqu vasqu deleted the fix-fp8-experts branch January 8, 2026 16:37