-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[Ascend] qwen optimization #12078
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[Ascend] qwen optimization #12078
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
82b26a3
[Ascend]Qwen performance optimization
Liwansi 399514e
merge
chenxu140 67ceb79
[Ascend] Qwen performance optimization
chenxu140 cb50af6
update torch_npu and triton-ascend version
iforgetmyname fe77c39
Merge branch 'main' into main_qwen
iforgetmyname 9011e22
update torch_npu version
iforgetmyname a2e5c5a
Merge branch 'main' into main_qwen
iforgetmyname File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,10 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| if _is_npu: | ||
| import torch_npu | ||
|
|
||
|
|
||
| class DeepEPMoE(FusedMoE): | ||
| """ | ||
| MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main) | ||
|
|
@@ -411,9 +415,142 @@ def npu_fused_moe_without_routing_weights_bf16( | |
| return hidden_states | ||
|
|
||
|
|
||
| class NpuFuseEPMoE(DeepEPMoE): | ||
| def __init__( | ||
| self, | ||
| num_experts: int, | ||
| top_k: int, | ||
| hidden_size: int, | ||
| intermediate_size: int, | ||
| layer_id: int, | ||
| num_fused_shared_experts: int = 0, | ||
| params_dtype: Optional[torch.dtype] = None, | ||
| quant_config: Optional[QuantizationConfig] = None, | ||
| prefix: str = "", | ||
| activation: str = "silu", | ||
| routed_scaling_factor: Optional[float] = None, | ||
| ): | ||
| super().__init__( | ||
| num_experts=num_experts, | ||
| top_k=top_k, | ||
| hidden_size=hidden_size, | ||
| intermediate_size=intermediate_size, | ||
| layer_id=layer_id, | ||
| num_fused_shared_experts=num_fused_shared_experts, | ||
| params_dtype=params_dtype, | ||
| quant_config=quant_config, | ||
| prefix=prefix, | ||
| activation=activation, | ||
| routed_scaling_factor=routed_scaling_factor, | ||
| ) | ||
|
|
||
| self.quant_method.process_weights_after_loading = ( | ||
| self._process_weights_after_loading | ||
| ) | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| topk_output: TopKOutput, | ||
| forward_shared_experts=None, | ||
| alt_stream=None, | ||
| disable_sbo=False, | ||
| ): | ||
| return self.dispatcher.dispatch( | ||
| hidden_states=hidden_states, | ||
| topk_output=topk_output, | ||
| gmm1_permuted_weight=self.w13_weight, | ||
| gmm1_permuted_weight_scale=self.w13_weight_scale, | ||
| gmm2_weight=self.w2_weight, | ||
| gmm2_weight_scale=self.w2_weight_scale, | ||
| ).hidden_state | ||
|
|
||
| def release_weight_cache(self, weight: torch.Tensor): | ||
| # .contiguous() introduces additional memory overhead and needs to be released using resize_(0) | ||
| origin_weight = weight.data.transpose(1, 2) | ||
| new_weight = origin_weight.contiguous() | ||
| origin_weight.untyped_storage().resize_(0) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i can't really appreciate anything that does fix this issue there using |
||
| return new_weight | ||
|
|
||
| def permute_w13_weight_scale(self, w: torch.Tensor, tile_n: int): | ||
| if tile_n % 2 != 0: | ||
| raise ValueError(f"tile_n must be even, got {tile_n}") | ||
|
|
||
| *dims, n = w.shape | ||
| if n % tile_n != 0: | ||
| raise ValueError(f"Last dimension {n} must be divisible by tile_n {tile_n}") | ||
|
|
||
| w_reshaped = w.reshape(*dims, 2, n // tile_n, tile_n // 2) | ||
|
|
||
| # Permute the last two dimensions. | ||
| perm_order = list(range(len(dims))) + [-2, -3, -1] | ||
| w_permuted = w_reshaped.permute(perm_order) | ||
|
|
||
| return w_permuted.reshape(*dims, n) | ||
|
|
||
| def reshape_w13_weight(self, weight: torch.Tensor, dim: int, chunk_size: int = 64): | ||
| # Achieving greater computing power through reshape on Ascend. | ||
| original_shape = weight.shape | ||
| if dim < 0: | ||
| dim += len(original_shape) | ||
|
|
||
| if original_shape[dim] % (2 * chunk_size) != 0: | ||
| raise ValueError( | ||
| f"Dimension {dim} size {original_shape[dim]} must be divisible by {2 * chunk_size}" | ||
| ) | ||
|
|
||
| new_shape = ( | ||
| *original_shape[:dim], | ||
| 2, | ||
| original_shape[dim] // (2 * chunk_size), | ||
| chunk_size, | ||
| *original_shape[dim + 1 :], | ||
| ) | ||
|
|
||
| weight = weight.view(new_shape) | ||
| weight = weight.transpose(dim, dim + 1).contiguous() | ||
|
|
||
| return weight.view(*original_shape[:dim], -1, *original_shape[dim + 1 :]) | ||
|
|
||
| def _process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| w13 = self.release_weight_cache(layer.w13_weight) | ||
| torch_npu.npu_format_cast_(w13, 2) | ||
| cpu_w13 = w13.cpu() | ||
| w13 = self.reshape_w13_weight(cpu_w13, -1).npu() | ||
| torch_npu.npu_format_cast_(w13, 29) | ||
| layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False) | ||
|
|
||
| w2 = torch_npu.npu_format_cast(layer.w2_weight.data, 29) | ||
| layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False) | ||
|
|
||
| w13_scale = layer.w13_weight_scale.data.squeeze(-1).contiguous() | ||
| w13_scale = self.permute_w13_weight_scale(w13_scale, 128) | ||
| layer.w13_weight_scale = torch.nn.Parameter( | ||
| w13_scale.to(torch.float32), requires_grad=False | ||
| ) | ||
|
|
||
| w2_scale = layer.w2_weight_scale.data.squeeze(-1).contiguous() | ||
| layer.w2_weight_scale = torch.nn.Parameter( | ||
| w2_scale.to(torch.float32), requires_grad=False | ||
| ) | ||
|
|
||
| if hasattr(layer, "w13_weight_offset"): | ||
| layer.w13_weight_offset = torch.nn.Parameter( | ||
| layer.w13_weight_offset.data.squeeze(-1).contiguous(), | ||
| requires_grad=False, | ||
| ) | ||
| if hasattr(layer, "w2_weight_offset"): | ||
| layer.w2_weight_offset = torch.nn.Parameter( | ||
| layer.w2_weight_offset.data.squeeze(-1).contiguous(), | ||
| requires_grad=False, | ||
| ) | ||
|
|
||
|
|
||
| def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): | ||
| if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): | ||
| return DeepEPMoE | ||
| if get_moe_a2a_backend().is_ascend_fuseep(): | ||
| return NpuFuseEPMoE | ||
|
|
||
| # NEW: Direct FP4 detection (bypasses EP requirements) | ||
| # Check for FP4 quantization with TRTLLM flag, regardless of EP | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move to separate files
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will be solved in #13359