diff --git a/docker/npu.Dockerfile b/docker/npu.Dockerfile index 21a8f7edffb..54261e708a2 100644 --- a/docker/npu.Dockerfile +++ b/docker/npu.Dockerfile @@ -10,11 +10,10 @@ ARG PIP_INDEX_URL="https://pypi.org/simple/" ARG APTMIRROR="" ARG PYTORCH_VERSION="2.8.0" ARG TORCHVISION_VERSION="0.23.0" -ARG PTA_VERSION="v7.2.0-pytorch${PYTORCH_VERSION}" -ARG PTA_NAME="torch_npu-${PYTORCH_VERSION}-cp311-cp311-manylinux_2_28_aarch64.whl" -ARG PTA_URL="https://gitcode.com/Ascend/pytorch/releases/download/${PTA_VERSION}/${PTA_NAME}" -ARG TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0%2Bgitb0ea0850-cp311-cp311-linux_aarch64.whl" -ARG BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/Ascend-BiSheng-toolkit_aarch64.run" +ARG PTA_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/torch_npu/torch_npu-2.8.0.post2.dev20251113-cp311-cp311-manylinux_2_28_aarch64.whl" +ARG TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend/triton_ascend-3.2.0.dev2025112116-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" +ARG BISHENG_NAME="Ascend-BiSheng-toolkit_aarch64_20251121.run" +ARG BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend/${BISHENG_NAME}" ARG SGLANG_TAG=main ARG ASCEND_CANN_PATH=/usr/local/Ascend/ascend-toolkit ARG SGLANG_KERNEL_NPU_TAG=main @@ -64,13 +63,13 @@ RUN ${PIP_INSTALL} sglang-router ### Install PyTorch and PTA -RUN (${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/cpu) && \ - (wget -O "${PTA_NAME}" "${PTA_URL}" && ${PIP_INSTALL} "./${PTA_NAME}" && rm "./${PTA_NAME}") +RUN (${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/cpu) \ + && (${PIP_INSTALL} ${PTA_URL}) # TODO: install from pypi released triton-ascend -RUN ${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11 && \ - ${PIP_INSTALL} ${TRITON_ASCEND_URL} +RUN (${PIP_INSTALL} pybind11) \ + && (${PIP_INSTALL} ${TRITON_ASCEND_URL}) # Install SGLang RUN git clone https://github.com/sgl-project/sglang --branch $SGLANG_TAG && \ @@ -96,6 +95,6 @@ RUN wget https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/ops/CANN-custom_o ${PIP_INSTALL} ./custom_ops-1.0.$DEVICE_TYPE-cp311-cp311-linux_aarch64.whl # Install Bisheng -RUN wget ${BISHENG_URL} && chmod a+x Ascend-BiSheng-toolkit_aarch64.run && ./Ascend-BiSheng-toolkit_aarch64.run --install && rm Ascend-BiSheng-toolkit_aarch64.run +RUN wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./${BISHENG_NAME}" --install && rm "${BISHENG_NAME}" CMD ["/bin/bash"] diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 54668423e4a..da1809816b6 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -625,53 +625,93 @@ def forward_decode_graph( ) if not self.use_mla: - k_cache = forward_batch.token_to_kv_pool.get_key_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim) - query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim) - if self.forward_metadata.seq_lens_cpu_int is None: - actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list + num_tokens = q.shape[0] + """PA will support bs None: + w13 = self.release_weight_cache(layer.w13_weight) + torch_npu.npu_format_cast_(w13, 2) + cpu_w13 = w13.cpu() + w13 = self.reshape_w13_weight(cpu_w13, -1).npu() + torch_npu.npu_format_cast_(w13, 29) + layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False) + + w2 = torch_npu.npu_format_cast(layer.w2_weight.data, 29) + layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False) + + w13_scale = layer.w13_weight_scale.data.squeeze(-1).contiguous() + w13_scale = self.permute_w13_weight_scale(w13_scale, 128) + layer.w13_weight_scale = torch.nn.Parameter( + w13_scale.to(torch.float32), requires_grad=False + ) + + w2_scale = layer.w2_weight_scale.data.squeeze(-1).contiguous() + layer.w2_weight_scale = torch.nn.Parameter( + w2_scale.to(torch.float32), requires_grad=False + ) + + if hasattr(layer, "w13_weight_offset"): + layer.w13_weight_offset = torch.nn.Parameter( + layer.w13_weight_offset.data.squeeze(-1).contiguous(), + requires_grad=False, + ) + if hasattr(layer, "w2_weight_offset"): + layer.w2_weight_offset = torch.nn.Parameter( + layer.w2_weight_offset.data.squeeze(-1).contiguous(), + requires_grad=False, + ) + + def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): return DeepEPMoE + if get_moe_a2a_backend().is_ascend_fuseep(): + return NpuFuseEPMoE # NEW: Direct FP4 detection (bypasses EP requirements) # Check for FP4 quantization with TRTLLM flag, regardless of EP diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 8b7350c98df..0847d62f24e 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -93,6 +93,18 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher: async_finish=True, return_recv_hook=True, ) + elif a2a_backend.is_ascend_fuseep(): + from sglang.srt.layers.moe.token_dispatcher import NpuFuseEPDispatcher + + return NpuFuseEPDispatcher( + group=get_tp_group().device_group, + router_topk=moe_runner_config.top_k, + permute_fusion=True, + num_experts=moe_runner_config.num_experts, + num_local_experts=moe_runner_config.num_local_experts, + hidden_size=moe_runner_config.hidden_size, + params_dtype=moe_runner_config.params_dtype, + ) else: raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}") diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py index 26909c6c17a..1eac345ce89 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py @@ -16,6 +16,7 @@ DeepEPNormalCombineInput, DeepEPNormalDispatchOutput, ) +from sglang.srt.layers.moe.token_dispatcher.fuseep import NpuFuseEPDispatcher from sglang.srt.layers.moe.token_dispatcher.mooncake import ( MooncakeCombineInput, MooncakeDispatchOutput, @@ -48,4 +49,5 @@ "DeepEPLLDispatchOutput", "DeepEPLLCombineInput", "DeepEPNormalCombineInput", + "NpuFuseEPDispatcher", ] diff --git a/python/sglang/srt/layers/moe/token_dispatcher/fuseep.py b/python/sglang/srt/layers/moe/token_dispatcher/fuseep.py new file mode 100644 index 00000000000..c187352c3b8 --- /dev/null +++ b/python/sglang/srt/layers/moe/token_dispatcher/fuseep.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +from typing import NamedTuple + +import torch + +from sglang.srt.layers.moe.token_dispatcher.base import ( + BaseDispatcher, + CombineInput, + CombineInputFormat, + DispatchOutput, + DispatchOutputFormat, +) +from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer +from sglang.srt.layers.moe.topk import TopKOutput +from sglang.srt.layers.moe.utils import DeepEPMode +from sglang.srt.utils import get_int_env_var + +logger = logging.getLogger(__name__) + + +class FuseEPDispatchOutput(NamedTuple): + """DeepEP low latency dispatch output.""" + + hidden_state: torch.Tensor + + @property + def format(self) -> DispatchOutputFormat: + return DispatchOutputFormat.DEEPEP_LL + + +class FuseEPCombineInput(NamedTuple): + """DeepEP low latency combine input.""" + + hidden_state: torch.Tensor + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.DEEPEP_LL + + +class NpuFuseEPDispatcher(BaseDispatcher): + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + deepep_mode: DeepEPMode = DeepEPMode.LOW_LATENCY, + ): + self.group = group + self.router_topk = router_topk + self.permute_fusion = permute_fusion + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.params_dtype = params_dtype + self.deepep_mode = deepep_mode + + self.params_bytes = 2 + self.num_max_dispatch_tokens_per_rank = get_int_env_var( + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128 + ) + + def dispatch( + self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs + ) -> DispatchOutput: + hidden_states, _ = self._get_buffer().fused_deep_moe( + hidden_states, + topk_idx=topk_output.topk_ids, + topk_weights=topk_output.topk_weights, + gmm1_permuted_weight=kwargs["gmm1_permuted_weight"], + gmm1_permuted_weight_scale=kwargs["gmm1_permuted_weight_scale"], + gmm2_weight=kwargs["gmm2_weight"], + gmm2_weight_scale=kwargs["gmm2_weight_scale"], + num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank, + num_experts=self.num_experts, + ) + return FuseEPDispatchOutput(hidden_states) + + def combine(self, combine_input: CombineInput, **kwargs) -> torch.Tensor: + pass + + def _get_buffer(self): + DeepEPBuffer.set_dispatch_mode_as_low_latency() + return DeepEPBuffer.get_deepep_buffer( + self.group, + self.hidden_size, + self.params_bytes, + self.deepep_mode, + self.num_max_dispatch_tokens_per_rank, + self.num_experts, + ) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 066037db980..b4c789a84fa 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -106,6 +106,7 @@ def _kimi_k2_moe_fused_gate( raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") if _is_npu: import torch_npu + from sgl_kernel_npu.norm.l1_norm import l1_norm # -------------------------------- TopKConfig --------------------------------------- @@ -363,15 +364,14 @@ def forward_npu( router_logits, k=self.topk_config.top_k, ) - topk_weights = topk_weights.to(torch.float32) if renormalize: - topk_weights_sum = ( - topk_weights.sum(dim=-1, keepdim=True) + topk_weights = l1_norm( + topk_weights if self.topk_config.num_fused_shared_experts == 0 - else topk_weights[:, :-1].sum(dim=-1, keepdim=True) + else topk_weights[:, :-1] ) - topk_weights = topk_weights / topk_weights_sum + topk_weights = topk_weights.to(torch.float32) if expert_location_dispatch_info is not None: topk_ids = topk_ids_logical_to_physical( diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 28805c070cc..85d28b148d7 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -24,6 +24,7 @@ class MoeA2ABackend(Enum): NONE = "none" DEEPEP = "deepep" MOONCAKE = "mooncake" + ASCEND_FUSEEP = "ascend_fuseep" @classmethod def _missing_(cls, value): @@ -43,6 +44,9 @@ def is_deepep(self): def is_mooncake(self): return self == MoeA2ABackend.MOONCAKE + def is_ascend_fuseep(self): + return self == MoeA2ABackend.ASCEND_FUSEEP + class MoeRunnerBackend(Enum): diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 3212f02cca5..afb7d4e6c87 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -32,6 +32,7 @@ from sglang.srt.utils import ( apply_module_patch, cpu_has_amx_support, + get_bool_env_var, is_cpu, is_cuda, is_npu, @@ -63,6 +64,8 @@ else: useMindIETurbo = True + from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias + logger = logging.getLogger(__name__) @@ -87,10 +90,13 @@ def _rmsnorm_forward_oot( if not x.is_contiguous(): x = x.contiguous() if residual is not None: - out, _, residual_out = torch_npu.npu_add_rms_norm( - residual, x, self.weight.data, self.variance_epsilon + out, residual_out = add_rmsnorm_bias( + x, + residual, + self.weight.data, + self.bias, + self.variance_epsilon, ) - out = out + self.bias return out.to(x.dtype), residual_out out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0] @@ -1072,15 +1078,23 @@ def create_weights( layer.register_parameter("w2_weight_offset", w2_weight_offset) set_weight_attrs(w2_weight_offset, extra_weight_attrs) + def release_weight_cache(self, weight: torch.Tensor): + # .contiguous() introduces additional memory overhead and needs to be released using resize_(0) + origin_weight = weight.data.transpose(1, 2) + new_weight = origin_weight.contiguous() + origin_weight.untyped_storage().resize_(0) + return new_weight + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.w13_weight = Parameter( - layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False - ) - layer.w2_weight = Parameter( - layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False - ) + weight_data = self.release_weight_cache(layer.w13_weight.data) + layer.w13_weight = Parameter(weight_data, requires_grad=False) + + weight_data = self.release_weight_cache(layer.w2_weight.data) + layer.w2_weight = Parameter(weight_data, requires_grad=False) + layer.w13_weight_scale = Parameter( - layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False + layer.w13_weight_scale.data.squeeze(-1).contiguous().to(torch.float32), + requires_grad=False, ) layer.w2_weight_scale = Parameter( layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False @@ -1092,6 +1106,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False ) + if get_bool_env_var("ENABLE_ASCEND_MOE_NZ"): + layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, 29) + layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, 29) + def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): @@ -1145,7 +1163,7 @@ def apply_without_routing_weights( # act_fn: swiglu hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( x=hidden_states, - weight_scale=layer.w13_weight_scale.to(torch.float32), + weight_scale=layer.w13_weight_scale, activation_scale=hidden_states_scale, bias=None, quant_scale=None, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index b14ceaed17f..0cdb7e1ae8c 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -135,6 +135,7 @@ def __init__( self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( self._apply_rotary_emb_wrapped ) + self.position_cos, self.position_sin = None, None def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" @@ -202,6 +203,18 @@ def _ensure_cos_sin_cache_length(self, needed_max_pos: int): device=device, dtype=dtype ) + def get_cos_sin_with_position(self, positions): + cos_sin = self.cos_sin_cache.index_select(0, positions.flatten()) + last_dim = cos_sin.size()[-1] + cos, sin = ( + cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2) + ) + # BSNH + self.position_cos, self.position_sin = ( + cos.view(-1, 1, 1, last_dim).contiguous(), + sin.view(-1, 1, 1, last_dim).contiguous(), + ) + def forward_native( self, positions: torch.Tensor, diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index abc276ff158..aa9241a7715 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -19,11 +19,13 @@ import os import threading from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Dict, Optional, Union +import numpy as np import torch -from sglang.srt.configs.model_config import is_deepseek_nsa +from sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa +from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.utils import is_npu @@ -47,6 +49,20 @@ class NPUGraphRunner(CudaGraphRunner): def __init__(self, model_runner: ModelRunner): super().__init__(model_runner) + self.update_attr_name = None + self.update_attr_type = None + self.model_runner = model_runner + self._init_arch_map() + + def _init_arch_map(self): + self.attr_name: Dict[str, str] = { + AttentionArch.MLA: "actual_seq_lengths_kv", + AttentionArch.MHA: "context_lens", + } + self.attr_type: Dict[str, Union[list, torch.Tensor]] = { + AttentionArch.MLA: [], + AttentionArch.MHA: torch.Tensor(), + } def _create_device_graph(self): return torch.npu.NPUGraph() @@ -61,9 +77,22 @@ def _capture_graph(self, graph, pool, stream, run_once_fn): out = run_once_fn() return out + def _get_update_attr_name(self, model_runner): + if self.bs < get_attention_tp_size(): + return self.attr_name[AttentionArch.MLA] + return self.attr_name[model_runner.model_config.attention_arch] + + def _get_update_attr_type(self, model_runner): + if self.bs < get_attention_tp_size(): + return self.attr_type[AttentionArch.MLA] + return self.attr_type[model_runner.model_config.attention_arch] + def _update_inputs(self, seq_lens): + if isinstance(self.update_attr_type, torch.Tensor): + seq_lens = torch.from_numpy(np.array(seq_lens).astype(np.int32)) + self.graphs[self.bs].update( - cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] + cpu_update_input=[{self.update_attr_name: seq_lens}] ) def _cache_loc_dtype(self): @@ -110,6 +139,8 @@ def replay( self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions) + self.update_attr_name = self._get_update_attr_name(self.model_runner) + self.update_attr_type = self._get_update_attr_type(self.model_runner) # Replay if not is_deepseek_nsa(self.model_runner.model_config.hf_config): if forward_batch.forward_mode.is_target_verify(): diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 9a9ac4da8be..7d8603a0ee9 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -44,6 +44,9 @@ _is_cuda = is_cuda() _is_npu = is_npu() +if _is_npu: + from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope + class Qwen3Attention(nn.Module): def __init__( @@ -161,6 +164,33 @@ def _apply_qk_norm( k = k_by_head.view(k.shape) return q, k + def forward_prepare_native(self, positions, hidden_states): + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + return q, k, v + + def forward_prepare_npu(self, positions, hidden_states): + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn.layer_id == 0: + self.rotary_emb.get_cos_sin_with_position(positions) + q, k, v = split_qkv_rmsnorm_rope( + qkv, + self.rotary_emb.position_sin, + self.rotary_emb.position_cos, + self.q_norm.weight, + self.k_norm.weight, + self.q_size, + self.kv_size, + self.head_dim, + self.q_norm.variance_epsilon, + q_bias=getattr(self.q_norm, "bias", None), + k_bias=getattr(self.k_norm, "bias", None), + ) + return q, k, v + def forward( self, positions: torch.Tensor, @@ -170,10 +200,16 @@ def forward( if get_global_server_args().rl_on_policy_target is not None: hidden_states = hidden_states.bfloat16() - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb(positions, q, k) + if not _is_npu: + q, k, v = self.forward_prepare_native( + positions=positions, + hidden_states=hidden_states, + ) + else: + q, k, v = self.forward_prepare_npu( + positions=positions, + hidden_states=hidden_states, + ) if get_global_server_args().rl_on_policy_target is not None: q = q.to(torch.bfloat16) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 3a14bc5fe34..3fbe8125729 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -70,6 +70,7 @@ is_cuda, is_flashinfer_available, is_non_idle_and_non_empty, + is_npu, ) Qwen3MoeConfig = None @@ -78,6 +79,10 @@ logger = logging.getLogger(__name__) _is_cuda = is_cuda() +_is_npu = is_npu() + +if _is_npu: + from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope class Qwen3MoeSparseMoeBlock(nn.Module): @@ -139,7 +144,10 @@ def forward( use_reduce_scatter: bool = False, ) -> torch.Tensor: - if not get_moe_a2a_backend().is_deepep(): + if ( + not get_moe_a2a_backend().is_deepep() + and not get_moe_a2a_backend().is_ascend_fuseep() + ): return self.forward_normal( hidden_states, should_allreduce_fusion, use_reduce_scatter ) @@ -392,14 +400,37 @@ def op_core(self, state): state.pop("attn_intermediate_state") ) - def forward_prepare( + def forward_prepare_npu( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): + qkv, _ = self.qkv_proj(hidden_states) + if self.attn.layer_id == 0: + self.rotary_emb.get_cos_sin_with_position(positions) + q, k, v = split_qkv_rmsnorm_rope( + qkv, + self.rotary_emb.position_sin, + self.rotary_emb.position_cos, + self.q_norm.weight, + self.k_norm.weight, + self.q_size, + self.kv_size, + self.head_dim, + self.q_norm.variance_epsilon, + q_bias=getattr(self.q_norm, "bias", None), + k_bias=getattr(self.k_norm, "bias", None), + ) + inner_state = q, k, v, forward_batch + return None, forward_batch, inner_state + + def forward_prepare_native( self, positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, ): - if hidden_states.shape[0] == 0: - return hidden_states, forward_batch, None qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) @@ -421,6 +452,27 @@ def forward_prepare( inner_state = q, k, v, forward_batch return None, forward_batch, inner_state + def forward_prepare( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): + if hidden_states.shape[0] == 0: + return hidden_states, forward_batch, None + if not _is_npu: + return self.forward_prepare_native( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + else: + return self.forward_prepare_npu( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + def forward_core(self, intermediate_state): hidden_states, forward_batch, inner_state = intermediate_state if inner_state is None: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0c1b9634e5a..4a06fdc8092 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -404,7 +404,7 @@ class ServerArgs: # Expert parallelism ep_size: int = 1 - moe_a2a_backend: Literal["none", "deepep", "mooncake"] = "none" + moe_a2a_backend: Literal["none", "deepep", "mooncake", "ascend_fuseep"] = "none" moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False @@ -1516,6 +1516,12 @@ def _handle_a2a_moe(self): f"Mooncake MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) + if self.moe_a2a_backend == "ascend_fuseep": + self.ep_size = self.tp_size + logger.warning( + f"Ascend fused EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) + def _handle_eplb_and_dispatch(self): if self.enable_eplb and (self.expert_distribution_recorder_mode is None): self.expert_distribution_recorder_mode = "stat" @@ -1523,7 +1529,7 @@ def _handle_eplb_and_dispatch(self): "EPLB is enabled. The expert_distribution_recorder_mode is automatically set." ) - if (self.enable_eplb or (self.init_expert_location is not None)) and ( + if (self.enable_eplb or (self.init_expert_location != "trivial")) and ( self.ep_dispatch_algorithm is None ): self.ep_dispatch_algorithm = "static" @@ -2969,7 +2975,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--moe-a2a-backend", type=str, - choices=["none", "deepep", "mooncake"], + choices=["none", "deepep", "mooncake", "ascend_fuseep"], default=ServerArgs.moe_a2a_backend, help="Choose the backend for MoE A2A.", ) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index a7e25799b98..656dfcaabe6 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -633,38 +633,48 @@ def get_cmo_stream(): AIV or communication kernels, aiming to overlap the memory access time. """ global cmo_stream - if cmo_stream is None: - cmo_stream = torch.get_device_module().Stream() return cmo_stream -def prepare_weight_cache(handle, cache): +def set_cmo_stream(stream): + global cmo_stream + cmo_stream = stream + + +def prepare_weight_cache(handle, cache, PREFETCH_MAX_SIZE=1000000000): + """ + PREFETCH_MAX_SIZE: maximum size (bytes) for each prefetch operation. + This affects the time spent in prefetch: + time ≈ PREFETCH_MAX_SIZE / system_bandwidth + """ import torch_npu - NPU_PREFETCH_MAX_SIZE_BYTES = ( - 1000000000 # 1GB, a large value to prefetch entire weight - ) stream = get_cmo_stream() - stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(stream): + if stream is None: + stream = torch.get_device_module().Stream() + set_cmo_stream(stream) + stream.wait_stream(torch.get_device_module().current_stream()) + with torch.get_device_module().stream(stream): if isinstance(cache, list): for weight in cache: torch_npu.npu_prefetch( weight, handle, - NPU_PREFETCH_MAX_SIZE_BYTES, + PREFETCH_MAX_SIZE, ) else: torch_npu.npu_prefetch( cache, handle, - NPU_PREFETCH_MAX_SIZE_BYTES, + PREFETCH_MAX_SIZE, ) def wait_cmo_stream(): - cur_stream = torch.get_device_module().current_stream() - cur_stream.wait_stream(get_cmo_stream()) + stream = get_cmo_stream() + if stream is not None: + cur_stream = torch.get_device_module().current_stream() + cur_stream.wait_stream(stream) @lru_cache(maxsize=1) diff --git a/scripts/ci/npu_ci_install_dependency.sh b/scripts/ci/npu_ci_install_dependency.sh index ce092ed5b35..24192d6df35 100755 --- a/scripts/ci/npu_ci_install_dependency.sh +++ b/scripts/ci/npu_ci_install_dependency.sh @@ -21,7 +21,7 @@ apt update -y && apt install -y \ update-ca-certificates ${PIP_INSTALL} --upgrade pip # Pin wheel to 0.45.1, REF: https://github.com/pypa/wheel/issues/662 -${PIP_INSTALL} wheel==0.45.1 +${PIP_INSTALL} wheel==0.45.1 pybind11 ### Install MemFabric @@ -33,22 +33,18 @@ PYTORCH_VERSION="2.8.0" TORCHVISION_VERSION="0.23.0" ${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/cpu -PTA_VERSION="v7.2.0-pytorch${PYTORCH_VERSION}" -PTA_NAME="torch_npu-${PYTORCH_VERSION}-cp311-cp311-manylinux_2_28_aarch64.whl" -PTA_URL="https://gitcode.com/Ascend/pytorch/releases/download/${PTA_VERSION}/${PTA_NAME}" -wget -O "${PTA_NAME}" "${PTA_URL}" && ${PIP_INSTALL} "./${PTA_NAME}" +PTA_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/torch_npu/torch_npu-2.8.0.post2.dev20251113-cp311-cp311-manylinux_2_28_aarch64.whl" +${PIP_INSTALL} ${PTA_URL} ### Install Triton-Ascend -TRITON_ASCEND_NAME="triton_ascend-3.2.0+gitb0ea0850-cp311-cp311-linux_aarch64.whl" -TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0%2Bgitb0ea0850-cp311-cp311-linux_aarch64.whl" -${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11 -wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}" +TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend/triton_ascend-3.2.0.dev2025112116-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" +${PIP_INSTALL} ${TRITON_ASCEND_URL} ### Install BiSheng -BISHENG_NAME="Ascend-BiSheng-toolkit_aarch64.run" -BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${BISHENG_NAME}" +BISHENG_NAME="Ascend-BiSheng-toolkit_aarch64_20251121.run" +BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend/${BISHENG_NAME}" wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./${BISHENG_NAME}" --install && rm "${BISHENG_NAME}"