Skip to content

Commit 67ceb79

Browse files
committed
[Ascend] Qwen performance optimization
1 parent 399514e commit 67ceb79

File tree

14 files changed

+482
-69
lines changed

14 files changed

+482
-69
lines changed

python/sglang/srt/layers/attention/ascend_backend.py

Lines changed: 84 additions & 25 deletions
Original file line number · Diff line number · Diff line change
@@ -625,34 +625,93 @@ def forward_decode_graph(
625625
)
626626

627627
if not self.use_mla:
628-
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
629-
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
630-
query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
631-
num_tokens = query.shape[0]
632-
attn_output = torch.empty(
633-
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
634-
dtype=query.dtype,
635-
device=query.device,
636-
)
637-
if self.forward_metadata.seq_lens_cpu_int is None:
638-
actual_seq_len_kv = torch.from_numpy(
639-
np.array(self.forward_metadata.seq_lens_cpu_list).astype(np.int32)
628+
num_tokens = q.shape[0]
629+
"""PA will support bs<tp in the later version of CANN"""
630+
if num_tokens < get_attention_tp_size():
631+
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
632+
layer.layer_id
633+
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
634+
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
635+
layer.layer_id
636+
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
637+
query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
638+
if self.forward_metadata.seq_lens_cpu_int is None:
639+
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
640+
else:
641+
actual_seq_len_kv = (
642+
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
643+
)
644+
num_tokens = query.shape[0]
645+
workspace = (
646+
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
647+
query,
648+
k_cache,
649+
v_cache,
650+
block_table=self.forward_metadata.block_tables,
651+
block_size=self.page_size,
652+
num_heads=layer.tp_q_head_num,
653+
num_key_value_heads=layer.tp_k_head_num,
654+
input_layout="BSH",
655+
scale=layer.scaling,
656+
actual_seq_lengths_kv=actual_seq_len_kv,
657+
)
640658
)
659+
output = torch.empty(
660+
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
661+
dtype=q.dtype,
662+
device=q.device,
663+
)
664+
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
665+
torch_npu.npu_fused_infer_attention_score.out(
666+
query,
667+
k_cache,
668+
v_cache,
669+
block_table=self.forward_metadata.block_tables,
670+
block_size=self.page_size,
671+
num_heads=layer.tp_q_head_num,
672+
num_key_value_heads=layer.tp_k_head_num,
673+
input_layout="BSH",
674+
scale=layer.scaling,
675+
actual_seq_lengths_kv=actual_seq_len_kv,
676+
workspace=workspace,
677+
out=[output, softmax_lse],
678+
)
679+
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
641680
else:
642-
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int
681+
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
682+
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
683+
layer.layer_id
684+
)
685+
query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
686+
num_tokens = query.shape[0]
687+
attn_output = torch.empty(
688+
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
689+
dtype=query.dtype,
690+
device=query.device,
691+
)
692+
if self.forward_metadata.seq_lens_cpu_int is None:
693+
actual_seq_len_kv = torch.from_numpy(
694+
np.array(self.forward_metadata.seq_lens_cpu_list).astype(
695+
np.int32
696+
)
697+
)
698+
else:
699+
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int
643700

644-
torch_npu._npu_paged_attention(
645-
query=query,
646-
key_cache=k_cache,
647-
value_cache=v_cache,
648-
num_heads=layer.tp_q_head_num,
649-
num_kv_heads=layer.tp_k_head_num,
650-
scale_value=layer.scaling,
651-
block_table=self.forward_metadata.block_tables,
652-
context_lens=actual_seq_len_kv,
653-
out=attn_output,
654-
)
655-
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
701+
torch_npu._npu_paged_attention(
702+
query=query,
703+
key_cache=k_cache,
704+
value_cache=v_cache,
705+
num_heads=layer.tp_q_head_num,
706+
num_kv_heads=layer.tp_k_head_num,
707+
scale_value=layer.scaling,
708+
block_table=self.forward_metadata.block_tables,
709+
context_lens=actual_seq_len_kv,
710+
out=attn_output,
711+
)
712+
return attn_output.view(
713+
num_tokens, layer.tp_q_head_num * layer.v_head_dim
714+
)
656715
else:
657716
c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
658717
k_rope_cache = k_rope.view(

python/sglang/srt/layers/moe/ep_moe/layer.py

Lines changed: 137 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -45,6 +45,10 @@
4545
logger = logging.getLogger(__name__)
4646

4747

48+
if _is_npu:
49+
import torch_npu
50+
51+
4852
class DeepEPMoE(FusedMoE):
4953
"""
5054
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
@@ -411,9 +415,142 @@ def npu_fused_moe_without_routing_weights_bf16(
411415
return hidden_states
412416

413417

418+
class NpuFuseEPMoE(DeepEPMoE):
419+
def __init__(
420+
self,
421+
num_experts: int,
422+
top_k: int,
423+
hidden_size: int,
424+
intermediate_size: int,
425+
layer_id: int,
426+
num_fused_shared_experts: int = 0,
427+
params_dtype: Optional[torch.dtype] = None,
428+
quant_config: Optional[QuantizationConfig] = None,
429+
prefix: str = "",
430+
activation: str = "silu",
431+
routed_scaling_factor: Optional[float] = None,
432+
):
433+
super().__init__(
434+
num_experts=num_experts,
435+
top_k=top_k,
436+
hidden_size=hidden_size,
437+
intermediate_size=intermediate_size,
438+
layer_id=layer_id,
439+
num_fused_shared_experts=num_fused_shared_experts,
440+
params_dtype=params_dtype,
441+
quant_config=quant_config,
442+
prefix=prefix,
443+
activation=activation,
444+
routed_scaling_factor=routed_scaling_factor,
445+
)
446+
447+
self.quant_method.process_weights_after_loading = (
448+
self._process_weights_after_loading
449+
)
450+
451+
def forward(
452+
self,
453+
hidden_states: torch.Tensor,
454+
topk_output: TopKOutput,
455+
forward_shared_experts=None,
456+
alt_stream=None,
457+
disable_sbo=False,
458+
):
459+
return self.dispatcher.dispatch(
460+
hidden_states=hidden_states,
461+
topk_output=topk_output,
462+
gmm1_permuted_weight=self.w13_weight,
463+
gmm1_permuted_weight_scale=self.w13_weight_scale,
464+
gmm2_weight=self.w2_weight,
465+
gmm2_weight_scale=self.w2_weight_scale,
466+
).hidden_state
467+
468+
def release_weight_cache(self, weight: torch.Tensor):
469+
# .contiguous() introduces additional memory overhead and needs to be released using resize_(0)
470+
origin_weight = weight.data.transpose(1, 2)
471+
new_weight = origin_weight.contiguous()
472+
origin_weight.untyped_storage().resize_(0)
473+
return new_weight
474+
475+
def permute_w13_weight_scale(self, w: torch.Tensor, tile_n: int):
476+
if tile_n % 2 != 0:
477+
raise ValueError(f"tile_n must be even, got {tile_n}")
478+
479+
*dims, n = w.shape
480+
if n % tile_n != 0:
481+
raise ValueError(f"Last dimension {n} must be divisible by tile_n {tile_n}")
482+
483+
w_reshaped = w.reshape(*dims, 2, n // tile_n, tile_n // 2)
484+
485+
# Permute the last two dimensions.
486+
perm_order = list(range(len(dims))) + [-2, -3, -1]
487+
w_permuted = w_reshaped.permute(perm_order)
488+
489+
return w_permuted.reshape(*dims, n)
490+
491+
def reshape_w13_weight(self, weight: torch.Tensor, dim: int, chunk_size: int = 64):
492+
# Achieving greater computing power through reshape on Ascend.
493+
original_shape = weight.shape
494+
if dim < 0:
495+
dim += len(original_shape)
496+
497+
if original_shape[dim] % (2 * chunk_size) != 0:
498+
raise ValueError(
499+
f"Dimension {dim} size {original_shape[dim]} must be divisible by {2 * chunk_size}"
500+
)
501+
502+
new_shape = (
503+
*original_shape[:dim],
504+
2,
505+
original_shape[dim] // (2 * chunk_size),
506+
chunk_size,
507+
*original_shape[dim + 1 :],
508+
)
509+
510+
weight = weight.view(new_shape)
511+
weight = weight.transpose(dim, dim + 1).contiguous()
512+
513+
return weight.view(*original_shape[:dim], -1, *original_shape[dim + 1 :])
514+
515+
def _process_weights_after_loading(self, layer: torch.nn.Module) -> None:
516+
w13 = self.release_weight_cache(layer.w13_weight)
517+
torch_npu.npu_format_cast_(w13, 2)
518+
cpu_w13 = w13.cpu()
519+
w13 = self.reshape_w13_weight(cpu_w13, -1).npu()
520+
torch_npu.npu_format_cast_(w13, 29)
521+
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
522+
523+
w2 = torch_npu.npu_format_cast(layer.w2_weight.data, 29)
524+
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
525+
526+
w13_scale = layer.w13_weight_scale.data.squeeze(-1).contiguous()
527+
w13_scale = self.permute_w13_weight_scale(w13_scale, 128)
528+
layer.w13_weight_scale = torch.nn.Parameter(
529+
w13_scale.to(torch.float32), requires_grad=False
530+
)
531+
532+
w2_scale = layer.w2_weight_scale.data.squeeze(-1).contiguous()
533+
layer.w2_weight_scale = torch.nn.Parameter(
534+
w2_scale.to(torch.float32), requires_grad=False
535+
)
536+
537+
if hasattr(layer, "w13_weight_offset"):
538+
layer.w13_weight_offset = torch.nn.Parameter(
539+
layer.w13_weight_offset.data.squeeze(-1).contiguous(),
540+
requires_grad=False,
541+
)
542+
if hasattr(layer, "w2_weight_offset"):
543+
layer.w2_weight_offset = torch.nn.Parameter(
544+
layer.w2_weight_offset.data.squeeze(-1).contiguous(),
545+
requires_grad=False,
546+
)
547+
548+
414549
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
415550
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
416551
return DeepEPMoE
552+
if get_moe_a2a_backend().is_ascend_fuseep():
553+
return NpuFuseEPMoE
417554

418555
# NEW: Direct FP4 detection (bypasses EP requirements)
419556
# Check for FP4 quantization with TRTLLM flag, regardless of EP

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 12 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -93,6 +93,18 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
9393
async_finish=True,
9494
return_recv_hook=True,
9595
)
96+
elif a2a_backend.is_ascend_fuseep():
97+
from sglang.srt.layers.moe.token_dispatcher import NpuFuseEPDispatcher
98+
99+
return NpuFuseEPDispatcher(
100+
group=get_tp_group().device_group,
101+
router_topk=moe_runner_config.top_k,
102+
permute_fusion=True,
103+
num_experts=moe_runner_config.num_experts,
104+
num_local_experts=moe_runner_config.num_local_experts,
105+
hidden_size=moe_runner_config.hidden_size,
106+
params_dtype=moe_runner_config.params_dtype,
107+
)
96108
else:
97109
raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}")
98110

python/sglang/srt/layers/moe/token_dispatcher/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -16,6 +16,7 @@
1616
DeepEPNormalCombineInput,
1717
DeepEPNormalDispatchOutput,
1818
)
19+
from sglang.srt.layers.moe.token_dispatcher.fuseep import NpuFuseEPDispatcher
1920
from sglang.srt.layers.moe.token_dispatcher.mooncake import (
2021
MooncakeCombineInput,
2122
MooncakeDispatchOutput,
@@ -48,4 +49,5 @@
4849
"DeepEPLLDispatchOutput",
4950
"DeepEPLLCombineInput",
5051
"DeepEPNormalCombineInput",
52+
"NpuFuseEPDispatcher",
5153
]

0 commit comments

Comments
 (0)