
Commit 4ce70e6

NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize fusion
1 parent 6237754 commit 4ce70e6

File tree

19 files changed, +700 -167 lines


python/sglang/srt/_custom_ops.py

Lines changed: 41 additions & 1 deletion
@@ -4,13 +4,53 @@
 
 import torch
 
-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_bool_env_var,
+    get_cmo_stream,
+    is_hip,
+    is_hpu,
+    is_npu,
+)
 
 logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = get_bool_env_var(
     "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
 )
 
+
+import sglang.srt.utils
+
+
+@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=())
+def wait_cmo_stream() -> None:
+    if is_npu() and get_cmo_stream():
+        sglang.srt.utils.wait_cmo_stream()
+
+
+@wait_cmo_stream.register_fake
+def wait_cmo_stream_fake() -> None:
+    pass
+
+
+def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None:
+    sglang.srt.utils.prepare_weight_cache(handle, cache)
+
+
+def prepare_weight_cache_register_fake(
+    handle: torch.Tensor, cache: List[torch.Tensor]
+) -> None:
+    pass
+
+
+direct_register_custom_op(
+    op_name="prepare_weight_cache",
+    op_func=prepare_weight_cache,
+    mutates_args=["handle"],
+    fake_impl=prepare_weight_cache_register_fake,
+)
+
+
 if not is_hpu():
     # ROCm does not use vllm custom allreduce
     if use_vllm_custom_allreduce and not is_hip():
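
The block above registers NPU-side helpers as torch custom ops with fake (meta) implementations, which is what lets graph compilation trace through them. A rough, self-contained sketch of the same pattern (generic names, not SGLang's helpers; assumes PyTorch >= 2.4 for torch.library.custom_op):

import torch

# Register a custom op; the real implementation may carry a device-side effect.
@torch.library.custom_op("demo::sync_side_stream", mutates_args=())
def sync_side_stream() -> None:
    # The real implementation would, e.g., wait on a secondary stream here.
    pass

# The fake implementation stands in for the op while the compiler traces the graph.
@sync_side_stream.register_fake
def sync_side_stream_fake() -> None:
    pass

@torch.compile
def step(x: torch.Tensor) -> torch.Tensor:
    torch.ops.demo.sync_side_stream()  # captured as an opaque node in the graph
    return x + 1

print(step(torch.zeros(4)))

The same reasoning presumably applies to prepare_weight_cache: direct_register_custom_op wires up a fake implementation (as shown in the diff) so the cache-priming call does not break graph capture.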

python/sglang/srt/configs/model_config.py

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,7 @@ def __init__(
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
         sampling_defaults: str = "openai",
         quantize_and_serve: bool = False,
+        enable_torch_compile: bool = False,
     ) -> None:
         # Parse args
         self.model_path = model_path
@@ -106,6 +107,7 @@ def __init__(
         self.model_impl = model_impl
         self.sampling_defaults = sampling_defaults
         self.quantize_and_serve = quantize_and_serve
+        self.enable_torch_compile = enable_torch_compile
 
         # Validate quantize_and_serve configuration
         self._validate_quantize_and_serve_config()
@@ -234,6 +236,7 @@ def from_server_args(
             model_impl=server_args.model_impl,
             sampling_defaults=server_args.sampling_defaults,
             quantize_and_serve=server_args.quantize_and_serve,
+            enable_torch_compile=server_args.enable_torch_compile,
             **kwargs,
         )
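
The new enable_torch_compile field simply mirrors the server argument of the same name into ModelConfig; judging from the ascend_backend.py change below, it is presumably what lets the Ascend attention backend select the graph-compilation-friendly decode path.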

python/sglang/srt/layers/attention/ascend_backend.py

Lines changed: 136 additions & 96 deletions
@@ -74,6 +74,7 @@ def update_verify_buffers_to_fill_after_draft(
 
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
+        self.enable_torch_compile = False
        self.forward_metadata = None
         self.device = model_runner.device
         self.page_size = model_runner.page_size
@@ -576,112 +577,151 @@ def forward_decode_graph(
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        if not self.use_mla:
-            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                layer.layer_id
-            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                layer.layer_id
-            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-            if self.forward_metadata.seq_lens_cpu_int is None:
-                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
-            else:
-                actual_seq_len_kv = (
-                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
-                )
+        if not self.use_mla and self.enable_torch_compile:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
             num_tokens = query.shape[0]
-            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                query,
-                k_cache,
-                v_cache,
-                block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
-                input_layout="BSH",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-            )
-            output = torch.empty(
-                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                dtype=q.dtype,
-                device=q.device,
-            )
-            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-            torch_npu.npu_fused_infer_attention_score.out(
-                query,
-                k_cache,
-                v_cache,
-                block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
-                input_layout="BSH",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-                workspace=workspace,
-                out=[output, softmax_lse],
-            )
-            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            k_rope_cache = k_rope.view(
-                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
-            )
-            c_kv_cache = c_kv.view(
-                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            attn_output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
             )
 
-            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
-            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
             if self.forward_metadata.seq_lens_cpu_int is None:
-                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
-            else:
-                actual_seq_len_kv = (
-                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                actual_seq_len_kv = torch.from_numpy(
+                    np.array(self.forward_metadata.seq_lens_cpu_list).astype(np.int32)
                 )
+            else:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int
 
-            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                q_nope,
-                c_kv_cache,
-                c_kv_cache,
-                query_rope=q_rope,
-                key_rope=k_rope_cache,
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
                 num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
                 block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                input_layout="BNSD",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                sparse_mode=0,
+                context_lens=actual_seq_len_kv,
+                out=attn_output,
             )
-            output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
-            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if not self.use_mla:
+                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                    layer.layer_id
+                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                    layer.layer_id
+                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+                query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+                if self.forward_metadata.seq_lens_cpu_int is None:
+                    actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+                else:
+                    actual_seq_len_kv = (
+                        self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                    )
+                num_tokens = query.shape[0]
+                workspace = (
+                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        query,
+                        k_cache,
+                        v_cache,
+                        block_table=self.forward_metadata.block_tables,
+                        block_size=self.page_size,
+                        num_heads=layer.tp_q_head_num,
+                        num_key_value_heads=layer.tp_k_head_num,
+                        input_layout="BSH",
+                        scale=layer.scaling,
+                        actual_seq_lengths_kv=actual_seq_len_kv,
+                    )
+                )
+                output = torch.empty(
+                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                    dtype=q.dtype,
+                    device=q.device,
+                )
+                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query,
+                    k_cache,
+                    v_cache,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    input_layout="BSH",
+                    scale=layer.scaling,
+                    actual_seq_lengths_kv=actual_seq_len_kv,
+                    workspace=workspace,
+                    out=[output, softmax_lse],
+                )
+                return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+            else:
+                c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(
+                    layer.layer_id
+                )
+                k_rope_cache = k_rope.view(
+                    -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+                )
+                c_kv_cache = c_kv.view(
+                    -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+                )
 
-            torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
-                c_kv_cache,
-                c_kv_cache,
-                query_rope=q_rope,
-                key_rope=k_rope_cache,
-                num_heads=layer.tp_q_head_num,
-                num_key_value_heads=layer.tp_k_head_num,
-                block_table=self.forward_metadata.block_tables,
-                block_size=self.page_size,
-                input_layout="BNSD",
-                scale=layer.scaling,
-                actual_seq_lengths_kv=actual_seq_len_kv,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                sparse_mode=0,
-                workspace=workspace,
-                out=[output, softmax_lse],
-            )
-            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
+                q_nope = q.view(
+                    -1, layer.tp_q_head_num, 1, self.kv_lora_rank
+                ).contiguous()
+                q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+                if self.forward_metadata.seq_lens_cpu_int is None:
+                    actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+                else:
+                    actual_seq_len_kv = (
+                        self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                    )
+
+                workspace = (
+                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        q_nope,
+                        c_kv_cache,
+                        c_kv_cache,
+                        query_rope=q_rope,
+                        key_rope=k_rope_cache,
+                        num_heads=layer.tp_q_head_num,
+                        num_key_value_heads=layer.tp_k_head_num,
+                        block_table=self.forward_metadata.block_tables,
+                        block_size=self.page_size,
+                        input_layout="BNSD",
+                        scale=layer.scaling,
+                        actual_seq_lengths_kv=actual_seq_len_kv,
+                        antiquant_mode=0,
+                        antiquant_scale=None,
+                        sparse_mode=0,
+                    )
+                )
+                output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
+                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+                torch_npu.npu_fused_infer_attention_score.out(
+                    q_nope,
+                    c_kv_cache,
+                    c_kv_cache,
+                    query_rope=q_rope,
+                    key_rope=k_rope_cache,
+                    num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    input_layout="BNSD",
+                    scale=layer.scaling,
+                    actual_seq_lengths_kv=actual_seq_len_kv,
+                    antiquant_mode=0,
+                    antiquant_scale=None,
+                    sparse_mode=0,
+                    workspace=workspace,
+                    out=[output, softmax_lse],
+                )
+                return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
 
     def forward_decode(
         self,
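
In the new torch-compile branch above, the per-request KV lengths stay on the tensor side (an int32 tensor passed as context_lens to _npu_paged_attention) instead of being materialized with .cpu().int().tolist() as in the fallback path. A rough torch.compile illustration of why that matters for graph capture (generic example, not NPU code; exact behaviour depends on Dynamo settings):

import torch

def lengths_as_list(seq_lens: torch.Tensor) -> torch.Tensor:
    # Host round-trip: .tolist() typically forces a graph break under torch.compile
    # (unless scalar-output capture is enabled).
    lens = seq_lens.cpu().int().tolist()
    return torch.tensor(lens, dtype=torch.int32)

def lengths_as_tensor(seq_lens: torch.Tensor) -> torch.Tensor:
    # Stays on-graph: no host synchronization is needed while tracing.
    return seq_lens.to(torch.int32)

compiled = torch.compile(lengths_as_tensor, fullgraph=True)
print(compiled(torch.tensor([3, 5, 7])))
# torch.compile(lengths_as_list, fullgraph=True) would usually fail to capture a
# single graph because of the .tolist() call.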

python/sglang/srt/layers/communicator.py

Lines changed: 2 additions & 3 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 SGLang Team
+# Copyright 2023-2025 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -51,7 +51,6 @@
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
-    prepare_weight_cache,
 )
 
 _is_flashinfer_available = is_flashinfer_available()
@@ -567,7 +566,7 @@ def _gather_hidden_states_and_residual(
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
         if context.cache is not None:
-            _ = prepare_weight_cache(hidden_states, context.cache)
+            torch.ops.sglang.prepare_weight_cache(hidden_states, context.cache)
         hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
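
Routing the call through torch.ops.sglang.prepare_weight_cache (instead of calling the Python helper imported from sglang.srt.utils directly) dispatches it as the custom op registered in _custom_ops.py above, so its fake implementation can presumably stand in for it when this module is traced for graph compilation.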
