-
Notifications
You must be signed in to change notification settings - Fork 31.7k
[Fp8] Fix experts
#43154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fp8] Fix experts
#43154
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
ArthurZucker
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes absolutely! I did not check but if a test is missing let's add one!
|
@3outeille note that this breaks TP but same as it is broken for now as well |
SunMarc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| top_k_weights: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| final_hidden_states = torch.zeros_like(hidden_states) | ||
| num_experts = top_k_weights.shape[1] | ||
| with torch.no_grad(): | ||
| expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) | ||
| expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) | ||
| expert_mask = expert_mask.permute(2, 1, 0) | ||
| expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||
|
|
||
| for expert_idx in expert_hit: | ||
| expert_idx = expert_idx[0] | ||
| if expert_idx == num_experts: | ||
| if expert_idx == self.num_experts: | ||
| continue | ||
| _, token_idx = torch.where(expert_mask[expert_idx]) | ||
| current_state = hidden_states.index_select(0, token_idx) | ||
| top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) | ||
| current_state = hidden_states[token_idx] | ||
| gate, up = self.linear( | ||
| current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scale_inv[expert_idx] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks ! maybe we can add a small comment to say that it was mostly copied from deepspeed_v3 modeling, so that we should propagate the changes here also in the future
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a comment, referencing mixtral
|
[For maintainers] Suggested jobs to run (before merge) run-slow: finegrained_fp8 |
| @unittest.skip(reason="Dependent on #42028, will be removed alongside that PR") | ||
| def test_quantized_moe_forward(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This acts as a sanity integration check but it depends on the minimax m2 PR (#42028) so I will remove this skip when merging that PR
I think this is the easiest way as these weights force the issue
|
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43154&sha=d0d860 |
As per title, the current fp8 experts implementation is wrong - likely related to #42456 (cc @3outeille)
Before:
top kexperts that are hit, i.e.num_experts = top_k_weights.shape[1]top_kAfter:
transformers/src/transformers/models/mixtral/modeling_mixtral.py
Lines 82 to 86 in 88a5623
transformers/src/transformers/models/mixtral/modeling_mixtral.py
Lines 88 to 98 in 88a5623
transformers/src/transformers/models/mixtral/modeling_mixtral.py
Line 92 in 88a5623
I'm not sure how long it has been this way, but it definitely is not correct atm. I'd be curious if I'm overseeing something here but without this fix #42028 will not work for the original model - my suspicion is on mixtral having a bias towards low indices which might have saved this oversight.