Merged
19 changes: 9 additions & 10 deletions docker/npu.Dockerfile
@@ -10,11 +10,10 @@ ARG PIP_INDEX_URL="https://pypi.org/simple/"
ARG APTMIRROR=""
ARG PYTORCH_VERSION="2.8.0"
ARG TORCHVISION_VERSION="0.23.0"
ARG PTA_VERSION="v7.2.0-pytorch${PYTORCH_VERSION}"
ARG PTA_NAME="torch_npu-${PYTORCH_VERSION}-cp311-cp311-manylinux_2_28_aarch64.whl"
ARG PTA_URL="https://gitcode.com/Ascend/pytorch/releases/download/${PTA_VERSION}/${PTA_NAME}"
ARG TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0%2Bgitb0ea0850-cp311-cp311-linux_aarch64.whl"
ARG BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/Ascend-BiSheng-toolkit_aarch64.run"
ARG PTA_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/torch_npu/torch_npu-2.8.0.post2.dev20251113-cp311-cp311-manylinux_2_28_aarch64.whl"
ARG TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend/triton_ascend-3.2.0.dev2025112116-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
ARG BISHENG_NAME="Ascend-BiSheng-toolkit_aarch64_20251121.run"
ARG BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend/${BISHENG_NAME}"
ARG SGLANG_TAG=main
ARG ASCEND_CANN_PATH=/usr/local/Ascend/ascend-toolkit
ARG SGLANG_KERNEL_NPU_TAG=main
@@ -64,13 +63,13 @@ RUN ${PIP_INSTALL} sglang-router


### Install PyTorch and PTA
RUN (${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/cpu) && \
(wget -O "${PTA_NAME}" "${PTA_URL}" && ${PIP_INSTALL} "./${PTA_NAME}" && rm "./${PTA_NAME}")
RUN (${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url https://download.pytorch.org/whl/cpu) \
&& (${PIP_INSTALL} ${PTA_URL})


# TODO: install from pypi released triton-ascend
RUN ${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11 && \
${PIP_INSTALL} ${TRITON_ASCEND_URL}
RUN (${PIP_INSTALL} pybind11) \
&& (${PIP_INSTALL} ${TRITON_ASCEND_URL})

# Install SGLang
RUN git clone https://github.com/sgl-project/sglang --branch $SGLANG_TAG && \
@@ -96,6 +95,6 @@ RUN wget https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/ops/CANN-custom_o
${PIP_INSTALL} ./custom_ops-1.0.$DEVICE_TYPE-cp311-cp311-linux_aarch64.whl

# Install Bisheng
RUN wget ${BISHENG_URL} && chmod a+x Ascend-BiSheng-toolkit_aarch64.run && ./Ascend-BiSheng-toolkit_aarch64.run --install && rm Ascend-BiSheng-toolkit_aarch64.run
RUN wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./${BISHENG_NAME}" --install && rm "${BISHENG_NAME}"

CMD ["/bin/bash"]
130 changes: 85 additions & 45 deletions python/sglang/srt/layers/attention/ascend_backend.py
@@ -625,53 +625,93 @@ def forward_decode_graph(
)

if not self.use_mla:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
num_tokens = q.shape[0]
"""PA will support bs<tp in the later version of CANN"""
if num_tokens < get_attention_tp_size():
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
else:
actual_seq_len_kv = (
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
)
num_tokens = query.shape[0]
workspace = (
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
)
)
output = torch.empty(
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
dtype=q.dtype,
device=q.device,
)
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
torch_npu.npu_fused_infer_attention_score.out(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
workspace=workspace,
out=[output, softmax_lse],
)
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
actual_seq_len_kv = (
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
)
query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
num_tokens = query.shape[0]
attn_output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_len_kv = torch.from_numpy(
np.array(self.forward_metadata.seq_lens_cpu_list).astype(
np.int32
)
)
else:
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int

torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
value_cache=v_cache,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=actual_seq_len_kv,
out=attn_output,
)
return attn_output.view(
num_tokens, layer.tp_q_head_num * layer.v_head_dim
)
num_tokens = query.shape[0]
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
)
output = torch.empty(
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
dtype=q.dtype,
device=q.device,
)
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
torch_npu.npu_fused_infer_attention_score.out(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
workspace=workspace,
out=[output, softmax_lse],
)
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
k_rope_cache = k_rope.view(
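A note on the new decode routing above: the added docstring ("PA will support bs<tp in the later version of CANN") means the current paged-attention op cannot yet handle decode batches smaller than the attention TP size, so those small batches fall back to the fused infer-attention kernel. A toy sketch of that routing decision follows; the helper is illustrative, not sglang API, and only get_attention_tp_size, npu_fused_infer_attention_score, and _npu_paged_attention come from the diff.

```python
def pick_decode_kernel(num_tokens: int, attention_tp_size: int) -> str:
    """Toy mirror of the branch added in forward_decode_graph (non-MLA path)."""
    if num_tokens < attention_tp_size:
        # e.g. tp=4 but only 2 sequences still decoding
        return "npu_fused_infer_attention_score"
    return "_npu_paged_attention"


assert pick_decode_kernel(num_tokens=2, attention_tp_size=4) == "npu_fused_infer_attention_score"
assert pick_decode_kernel(num_tokens=16, attention_tp_size=4) == "_npu_paged_attention"
```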
137 changes: 137 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -45,6 +45,10 @@
logger = logging.getLogger(__name__)


if _is_npu:
import torch_npu


class DeepEPMoE(FusedMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
@@ -411,9 +415,142 @@ def npu_fused_moe_without_routing_weights_bf16(
return hidden_states


class NpuFuseEPMoE(DeepEPMoE):
Review comment (Member): move to separate files
Reply (Collaborator): will be solved in #13359

def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)

self.quant_method.process_weights_after_loading = (
self._process_weights_after_loading
)

def forward(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
forward_shared_experts=None,
alt_stream=None,
disable_sbo=False,
):
return self.dispatcher.dispatch(
hidden_states=hidden_states,
topk_output=topk_output,
gmm1_permuted_weight=self.w13_weight,
gmm1_permuted_weight_scale=self.w13_weight_scale,
gmm2_weight=self.w2_weight,
gmm2_weight_scale=self.w2_weight_scale,
).hidden_state

def release_weight_cache(self, weight: torch.Tensor):
# .contiguous() introduces additional memory overhead and needs to be released using resize_(0)
origin_weight = weight.data.transpose(1, 2)
new_weight = origin_weight.contiguous()
origin_weight.untyped_storage().resize_(0)
Review comment (Contributor): this workaround is not needed; we have a fix in PR #11984, commit ff1a3bb
Reply (Collaborator, @iforgetmyname, Nov 22, 2025): I can't really see anything there that fixes this issue; using flatten is just another workaround from my view

return new_weight
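
For readers outside the NPU stack, here is a minimal standalone sketch of the storage-release trick used in release_weight_cache, on plain CPU tensors with no torch_npu (whether the workaround is still needed is exactly what the review thread above debates):

```python
import torch

# transpose() only creates a view over the original storage, contiguous() materializes
# a second buffer with the new layout, and resize_(0) on the original storage frees the
# first buffer so both copies are not held in memory at the same time.
weight = torch.randn(4, 8, 16)

transposed_view = weight.data.transpose(1, 2)  # still backed by weight's storage
packed = transposed_view.contiguous()          # fresh storage, transposed layout
transposed_view.untyped_storage().resize_(0)   # drop the original allocation

print(packed.shape)  # torch.Size([4, 16, 8]); only the repacked copy remains usable
```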

def permute_w13_weight_scale(self, w: torch.Tensor, tile_n: int):
if tile_n % 2 != 0:
raise ValueError(f"tile_n must be even, got {tile_n}")

*dims, n = w.shape
if n % tile_n != 0:
raise ValueError(f"Last dimension {n} must be divisible by tile_n {tile_n}")

w_reshaped = w.reshape(*dims, 2, n // tile_n, tile_n // 2)

# Permute the last two dimensions.
perm_order = list(range(len(dims))) + [-2, -3, -1]
w_permuted = w_reshaped.permute(perm_order)

return w_permuted.reshape(*dims, n)

def reshape_w13_weight(self, weight: torch.Tensor, dim: int, chunk_size: int = 64):
# Achieving greater computing power through reshape on Ascend.
original_shape = weight.shape
if dim < 0:
dim += len(original_shape)

if original_shape[dim] % (2 * chunk_size) != 0:
raise ValueError(
f"Dimension {dim} size {original_shape[dim]} must be divisible by {2 * chunk_size}"
)

new_shape = (
*original_shape[:dim],
2,
original_shape[dim] // (2 * chunk_size),
chunk_size,
*original_shape[dim + 1 :],
)

weight = weight.view(new_shape)
weight = weight.transpose(dim, dim + 1).contiguous()

return weight.view(*original_shape[:dim], -1, *original_shape[dim + 1 :])

def _process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13 = self.release_weight_cache(layer.w13_weight)
torch_npu.npu_format_cast_(w13, 2)
cpu_w13 = w13.cpu()
w13 = self.reshape_w13_weight(cpu_w13, -1).npu()
torch_npu.npu_format_cast_(w13, 29)
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)

w2 = torch_npu.npu_format_cast(layer.w2_weight.data, 29)
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)

w13_scale = layer.w13_weight_scale.data.squeeze(-1).contiguous()
w13_scale = self.permute_w13_weight_scale(w13_scale, 128)
layer.w13_weight_scale = torch.nn.Parameter(
w13_scale.to(torch.float32), requires_grad=False
)

w2_scale = layer.w2_weight_scale.data.squeeze(-1).contiguous()
layer.w2_weight_scale = torch.nn.Parameter(
w2_scale.to(torch.float32), requires_grad=False
)

if hasattr(layer, "w13_weight_offset"):
layer.w13_weight_offset = torch.nn.Parameter(
layer.w13_weight_offset.data.squeeze(-1).contiguous(),
requires_grad=False,
)
if hasattr(layer, "w2_weight_offset"):
layer.w2_weight_offset = torch.nn.Parameter(
layer.w2_weight_offset.data.squeeze(-1).contiguous(),
requires_grad=False,
)


def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
return DeepEPMoE
if get_moe_a2a_backend().is_ascend_fuseep():
return NpuFuseEPMoE

# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
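To make the w13 scale permutation above easier to follow, here is a hedged standalone replica of permute_w13_weight_scale run on a toy 1-D tensor. tile_n=4 is chosen only for readability (the diff calls it with 128), and the replica drops the validation checks present in the original.

```python
import torch


def permute_w13_weight_scale(w: torch.Tensor, tile_n: int) -> torch.Tensor:
    # Same reshape -> permute -> reshape as in the diff, minus the validation checks.
    *dims, n = w.shape
    w = w.reshape(*dims, 2, n // tile_n, tile_n // 2)
    w = w.permute(*range(len(dims)), -2, -3, -1)
    return w.reshape(*dims, n)


scale = torch.arange(8.0)  # the two halves of the last dim: [0..3] and [4..7]
print(permute_w13_weight_scale(scale, tile_n=4))
# tensor([0., 1., 4., 5., 2., 3., 6., 7.]) -- the halves are interleaved per half-tile
```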
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -93,6 +93,18 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
async_finish=True,
return_recv_hook=True,
)
elif a2a_backend.is_ascend_fuseep():
from sglang.srt.layers.moe.token_dispatcher import NpuFuseEPDispatcher

return NpuFuseEPDispatcher(
group=get_tp_group().device_group,
router_topk=moe_runner_config.top_k,
permute_fusion=True,
num_experts=moe_runner_config.num_experts,
num_local_experts=moe_runner_config.num_local_experts,
hidden_size=moe_runner_config.hidden_size,
params_dtype=moe_runner_config.params_dtype,
)
else:
raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}")

2 changes: 2 additions & 0 deletions python/sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -16,6 +16,7 @@
DeepEPNormalCombineInput,
DeepEPNormalDispatchOutput,
)
from sglang.srt.layers.moe.token_dispatcher.fuseep import NpuFuseEPDispatcher
from sglang.srt.layers.moe.token_dispatcher.mooncake import (
MooncakeCombineInput,
MooncakeDispatchOutput,
@@ -48,4 +49,5 @@
"DeepEPLLDispatchOutput",
"DeepEPLLCombineInput",
"DeepEPNormalCombineInput",
"NpuFuseEPDispatcher",
]
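
With the re-export above, downstream code can pull the dispatcher from the package root. A minimal usage sketch, assuming an NPU build of sglang where the fuseep module's own dependencies import cleanly:

```python
# Only the import path below comes from this diff; the print is illustrative.
from sglang.srt.layers.moe.token_dispatcher import NpuFuseEPDispatcher

print(NpuFuseEPDispatcher.__name__)
```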