
Conversation

Contributor

@Liwansi Liwansi commented Oct 24, 2025

Motivation

Related to #10337.

Modifications

-bugfix:

1. Memory bugfix (w8a8_int8.py): previously, both layer.w13_weight and layer.w2_weight occupied double memory during weight processing; this is now fixed.
2. Cache Management Operation (CMO) bugfix (common.py): in some circumstances (BS=1,2), a deadlock could occur due to issues with stream synchronization and waiting; this is now fixed.
3. EPLB bugfix: the EPLB index operator is now also introduced when the map path is not set.

-optimization:

1. Use the high-performance fused Triton op split_qkv_rmsnorm_rope, which involves modifications to both rotary_embedding.py and qwen3_moe.py (a reference sketch of the assumed semantics follows this list).
2. Use the high-performance Triton op l1_norm, which involves modifications to moe/topk.py (also sketched below).
3. Support PA (paged attention) inside NPUGraph now that the relevant software packages have been released, which involves modifications to npu_graph_runner.py and ascend_backend.py; see [Ascend] optimize Qwen3 on Ascend #10574.
4. Add a MoE a2a backend named ascend_fuseep, which only supports the decode phase under PD disaggregation and integrates the dispatch, gmm1, swiglu, gmm2, and combine ops. Performance improved by 10% on Qwen3 235B.
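
For reference, below is a plain-PyTorch sketch of the semantics the two fused Triton ops above are assumed to implement. The function names, argument names, and tensor layouts are illustrative assumptions, not the actual kernel signatures.

    import torch

    def rms_norm(x, weight, eps=1e-6):
        # Per-head RMSNorm over the last dimension.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    def split_qkv_rmsnorm_rope_reference(qkv, q_norm_w, k_norm_w, cos, sin,
                                         num_q_heads, num_kv_heads, head_dim):
        # Unfused reference: split the packed QKV projection, RMS-normalize
        # q/k per head, then apply rotary position embedding to q and k.
        # qkv:      [num_tokens, (num_q_heads + 2 * num_kv_heads) * head_dim]
        # cos, sin: [num_tokens, head_dim], gathered at each token's position.
        q, k, v = qkv.split(
            [num_q_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
            dim=-1,
        )
        q = rms_norm(q.reshape(-1, num_q_heads, head_dim), q_norm_w)
        k = rms_norm(k.reshape(-1, num_kv_heads, head_dim), k_norm_w)

        def rope(x):
            # "Rotate half" formulation of RoPE.
            x1, x2 = x.chunk(2, dim=-1)
            rotated = torch.cat((-x2, x1), dim=-1)
            return x * cos.unsqueeze(1) + rotated * sin.unsqueeze(1)

        return rope(q), rope(k), v.reshape(-1, num_kv_heads, head_dim)

    def l1_norm_topk_reference(topk_weights):
        # Assumed semantics of l1_norm: renormalize the selected top-k routing
        # weights so that they sum to 1 along the expert dimension.
        return topk_weights / topk_weights.sum(dim=-1, keepdim=True)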

Accuracy Tests

  • Qwen-32B: (accuracy results screenshot)
  • Qwen-235B: (accuracy results screenshot)

Benchmarking and Profiling

  • Qwen-235B

Because of some earlier bugs, we only have performance numbers from after the changes.

(benchmark screenshot)

Performance of fuseep (compared with deepep):

(fuseep vs. deepep screenshot)

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @Liwansi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a series of targeted optimizations for Qwen models, specifically for NPU (Ascend) hardware. The changes focus on enhancing the performance and memory efficiency of critical components such as paged attention, Mixture-of-Experts (MoE) gating, rotary position embeddings, and KV cache management by integrating specialized NPU kernels. Additionally, it includes a memory optimization for quantized weights and dynamic input handling for NPU graphs.

Highlights

  • NPU Paged Attention Optimization: Replaced the generic NPU fused attention kernel with a specialized _npu_paged_attention function, along with adjusted input tensor shapes and sequence length handling, to improve performance and efficiency for paged attention on Ascend devices.
  • MoE Gating Optimization: Introduced a new NPU fused kernel, npu_moe_gating_top_k_softmax, for Mixture-of-Experts (MoE) gating. This optimization, controllable via an environment variable, accelerates the expert selection process for MoE models like Qwen3-MoE (a reference sketch follows this list).
  • Rotary Embedding Optimization: Implemented a specialized NPU kernel, npu_apply_rotary_pos_emb, for rotary position embeddings. This optimization, particularly for head_size=128, involves pre-computing and storing cosine/sine values and passing the layer ID for context.
  • KV Cache Update Optimization: Switched to a more efficient _npu_reshape_and_cache kernel for updating the KV cache on NPU, replacing npu_scatter_nd_update_. This change, coupled with explicit type casting for indices, aims to improve memory access patterns and reduce overhead (also sketched after this list).
  • Memory Optimization for Quantized Weights: Added explicit memory deallocation for intermediate transposed tensors (w13_weight and w2_weight) during the processing of quantized weights. This change aims to reduce the memory footprint by promptly releasing unused memory.
  • Dynamic NPU Graph Input Handling: Adapted the NPU graph input update mechanism to differentiate between MLA (multi-head latent attention) and non-MLA attention architectures. This ensures correct and optimized handling of sequence length inputs (actual_seq_lengths_kv vs. context_lens) for NPU graphs.
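
As a reading aid for the gating and KV-cache highlights above, the sketches below show the assumed semantics in plain PyTorch. They are illustrations only, not the torch_npu implementations; the tensor layouts and function names are assumptions.

    import torch

    def moe_gating_topk_softmax_reference(router_logits, k):
        # Assumed semantics of npu_moe_gating_top_k_softmax: softmax over the
        # expert dimension, then select the top-k weights and expert ids per token.
        scores = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(scores, k, dim=-1)
        return topk_weights, topk_ids

    def reshape_and_cache_reference(key, value, key_cache, value_cache, slot_indices):
        # Assumed semantics of a paged "reshape and cache" KV update.
        # key, value:             [num_tokens, num_heads, head_dim]
        # key_cache, value_cache: [num_pages, page_size, num_heads, head_dim]
        # slot_indices:           [num_tokens], global slot = page_id * page_size + offset
        slot_indices = slot_indices.to(torch.long)
        key_cache.view(-1, *key.shape[1:])[slot_indices] = key
        value_cache.view(-1, *value.shape[1:])[slot_indices] = value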

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several optimizations for Qwen models on Ascend NPUs, including the use of new, specialized kernels for attention, MoE gating, and rotary embeddings. It also contains a memory optimization for weight loading. While most of the changes appear correct and beneficial, I've identified a critical issue in python/sglang/srt/mem_cache/memory_pool.py where an incorrect kernel is used for MLA KV cache updates, which will lead to runtime errors. Additionally, there is some commented-out code in python/sglang/srt/layers/moe/topk.py that should be removed for better code clarity. My review provides specific suggestions to address these points.

Comment on lines 1762 to 1776
        import torch_npu

        if loc.dtype != torch.int32:
            loc = loc.to(torch.int32)

        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
            key_cache=self.k_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            value_cache=self.v_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            slot_indices=loc,
        )
Contributor

critical

This change introduces a bug. The class AscendMLAPagedTokenToKVPool does not have self.head_num or self.head_dim attributes, which will cause an AttributeError. This class is for MLA (multi-head latent attention) and should use self.kv_lora_rank and self.qk_rope_head_dim.

Furthermore, the _npu_reshape_and_cache kernel seems designed for MHA (Multi-Head Attention) and may not be suitable for MLA, especially since MLA can have different dimensions for key and value components. The previous implementation using torch_npu.npu_scatter_nd_update_ was likely correct for MLA.

Also, the buffer indexing self.k_buffer[layer_id] is incorrect. It should be self.k_buffer[layer_id - self.start_layer] to account for pipeline parallelism stages.

Finally, the import torch_npu inside the method is against style guidelines and is redundant as it's already imported at the top of the file.

I recommend reverting this block to the previous implementation.

        torch_npu.npu_scatter_nd_update_(
            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
            loc.view(-1, 1),
            cache_k.view(-1, 1, self.kv_lora_rank),
        )
        torch_npu.npu_scatter_nd_update_(
            self.v_buffer[layer_id - self.start_layer].view(
                -1, 1, self.qk_rope_head_dim
            ),
            loc.view(-1, 1),
            cache_v.view(-1, 1, self.qk_rope_head_dim),
        )

Comment on lines 853 to 857
    # assert (
    #     num_token_non_padded is None
    # ), "num_token_non_padded is not yet supported in fused_topk_native"
    # assert expert_location_dispatch_info is None
    # assert not apply_routed_scaling_factor_on_output, "Not implemented"
Contributor

medium

Commented-out code should be removed to improve code clarity. If these assertions are no longer needed, please remove them instead of commenting them out.

Collaborator

Remove this useless code.

@ping1jing2 ping1jing2 changed the title qwen optimization [Ascend] qwen optimization Oct 26, 2025
@ping1jing2 ping1jing2 marked this pull request as draft October 26, 2025 19:36
)
topk_weights = topk_weights.to(torch.float)
else:
topk_weights, topk_ids = fused_topk_native(
Member

Style: Put the common branch (NV GPU) as the first branch in the code whenever possible.

torch_npu.npu_scatter_nd_update_(
self.v_buffer[layer_id - self.start_layer].view(
-1, 1, self.qk_rope_head_dim
import torch_npu
Member

Move the whole class into a new file, memory_pool_ascend.py.

)


def wait_cmo_stream():
Member

Move all NPU-specific functions into a separate file, python/sglang/srt/utils/npu_common.py.

correction_bias=correction_bias,
)
else:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
Collaborator

This kernel has already been introduced; please pull the latest code.

_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_use_gating_topk_fused = get_bool_env_var("SGLANG_USE_GATING_TOPK_FUSED") and _is_npu
Collaborator

we don't need this env var, use this kernel by default if it's robust enough

cache = torch.cat((cos, sin), dim=-1)
return cache

def _get_cos_sin_with_position(self, positions, layer_id):
Collaborator

give this a more general impl that benefits gpu as well
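
One possible device-agnostic shape for this, sketched under assumptions (only the method name comes from the diff; the cos_sin_cache layout and attribute names are assumed), is to index the precomputed cos/sin cache by the token positions and split it back into halves, which works the same way on GPU and NPU:

    def _get_cos_sin_with_position(self, positions, layer_id=None):
        # self.cos_sin_cache: [max_position, rotary_dim], built as cat([cos, sin], dim=-1)
        cos_sin = self.cos_sin_cache.index_select(0, positions.flatten())
        cos, sin = cos_sin.chunk(2, dim=-1)
        return cos, sin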

Collaborator

revert this change

and self.compatible_with_fused_kv_buffer
else None
),
layer_id=self.layer_id if _is_npu else None,
Collaborator

revert this file

@@ -1,2 +1,3 @@
# Temporarily do this to avoid changing all imports in the repo
from sglang.srt.utils.common import *
from sglang.srt.utils.npu_common import *
Collaborator

no need to change in this pr


from sglang.srt.environ import envs
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.utils.npu_common import get_npu_compiler_config, get_npu_memory_capacity
Collaborator

revert this file, no need to change in this pr

raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")


class NpuFuseEPMoE(DeepEPMoE):
Member

move to separate files

Collaborator

will be solved in #13359

)

def forward_prepare(
def forward_prepare_npu(
Contributor

Please try to avoid this model change and create a fusion pass for this kernel instead.
We have supported a fusion pass manager in this PR:
#11104

@iforgetmyname
Collaborator

/tag-and-rerun-ci

@ping1jing2 ping1jing2 self-assigned this Nov 20, 2025
@chenxu140 chenxu140 force-pushed the main_qwen branch 2 times, most recently from 96a5f1c to ba1e92f Compare November 21, 2025 10:34
# .contiguous() introduces additional memory overhead and needs to be released using resize_(0)
origin_weight = weight.data.transpose(1, 2)
new_weight = origin_weight.contiguous()
origin_weight.untyped_storage().resize_(0)
Contributor

This workaround is not needed; we have a fix in PR #11984 (commit ff1a3bb).

Collaborator

@iforgetmyname iforgetmyname Nov 22, 2025

I can't really see anything there that fixes this issue.

Using flatten is just another workaround, from my point of view.


@chenxu140 chenxu140 force-pushed the main_qwen branch 2 times, most recently from f5c8a7e to 8418724 Compare November 23, 2025 07:53
@iforgetmyname
Collaborator

/tag-and-rerun-ci

@iforgetmyname iforgetmyname merged commit 432ecf8 into sgl-project:main Nov 25, 2025
81 of 89 checks passed