Skip to content
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
4ce70e6
NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize
eshoguli Sep 17, 2025
b974460
pre-commit & refactoring
eshoguli Oct 31, 2025
7a7bde7
pre-commit
qyqc731 Nov 1, 2025
d8e2dc3
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 5, 2025
9bb7751
Merge branch 'main' into eshogulin/pass_manager: fix - custom_ops.py
eshoguli Nov 5, 2025
7048005
cleanup & refactoring
eshoguli Nov 10, 2025
eb240d9
Pass Manager fix
eshoguli Nov 10, 2025
29c1d89
Compilation: refactoring
eshoguli Nov 11, 2025
3e98d17
NPU Piecewise Graph
eshoguli Nov 8, 2025
3d9516a
rollback
eshoguli Nov 11, 2025
2c1b6fe
linter
eshoguli Nov 11, 2025
55016b0
refactoring
eshoguli Nov 11, 2025
fbff08d
refactoring
eshoguli Nov 11, 2025
3e5db77
Compilation: refactoring
eshoguli Nov 11, 2025
99d4497
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 12, 2025
30da7fe
model_type check
eshoguli Nov 13, 2025
36ef7e7
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 14, 2025
1808479
PiecewiseNpuGraphCompilerBackend quick fix
Nov 14, 2025
bcfc2c5
CompilationConfig reusage
Nov 17, 2025
a6a159d
--torch-compile-max-bs support
Nov 18, 2025
c08d076
TorchAir compilation support
XDaoHong Nov 14, 2025
73f2ee9
runner selection fix: model forward usage
eshoguli Nov 19, 2025
2f97641
add test for torchair
XDaoHong Nov 19, 2025
7154cf4
TorchAir compilation support: refactoring
eshoguli Nov 19, 2025
dfaee00
NPU Piecewise Graph: refactoring
eshoguli Nov 19, 2025
bec1b28
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 19, 2025
253c14d
linter fix after merge commit
eshoguli Nov 19, 2025
85d808e
NPUGraph compilation (fp16) & NPU Piecewise Graph tests
eshoguli Nov 19, 2025
11074d9
TorchAir compilation support: refactoring 2
eshoguli Nov 19, 2025
51ac4b4
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 Nov 20, 2025
e06675b
CompilationConfig comments fix + linter fix
eshoguli Nov 21, 2025
0c09c24
backend instantiation in get_compiler_backend
eshoguli Nov 21, 2025
00a0b9b
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 23, 2025
0b31746
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage…
eshoguli Nov 25, 2025
3b5c83b
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 26, 2025
7eefeee
linter fix
eshoguli Nov 26, 2025
8c63980
dynamo patch removing
eshoguli Nov 26, 2025
2e02568
fix on main branch: compilation
eshoguli Nov 27, 2025
966bbf4
Merge branch 'main' into eshogulin/pass_manager
eshoguli Nov 27, 2025
14092b3
auto merge fix
eshoguli Nov 27, 2025
f989147
tests suit update
eshoguli Nov 27, 2025
bf1251d
Add npu_add_rms_norm_dynamic_quant fuse
OrangeRedeng Nov 27, 2025
317174b
Merge branch 'eshogulin/pass_manager' of https://github.com/eshoguli/…
OrangeRedeng Nov 27, 2025
e6eb29c
NPU Graph compilation: attention architecture check
eshoguli Nov 27, 2025
caba95e
Add npu_add_rms_norm_dynamic_quant fuse: quick fix
eshoguli Nov 27, 2025
3f87879
Qwen3 MoE compilation support for NPU
eshoguli Nov 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion python/sglang/srt/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,47 @@

import torch

from sglang.srt.utils import is_hip, is_hpu, is_npu
from sglang.srt.utils import direct_register_custom_op, is_hip, is_hpu, is_npu

logger = logging.getLogger(__name__)


import sglang.srt.utils


@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=())
def wait_cmo_stream() -> None:
if sglang.srt.utils.get_cmo_stream():
sglang.srt.utils.wait_cmo_stream()


@wait_cmo_stream.register_fake
def wait_cmo_stream_fake() -> None:
pass


def get_cmo_stream() -> bool:
return True


def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None:
sglang.srt.utils.prepare_weight_cache(handle, cache)


def prepare_weight_cache_register_fake(
handle: torch.Tensor, cache: List[torch.Tensor]
) -> None:
pass


direct_register_custom_op(
op_name="prepare_weight_cache",
op_func=prepare_weight_cache,
mutates_args=["handle"],
fake_impl=prepare_weight_cache_register_fake,
)


if not is_hpu():
try:
import sgl_kernel
Expand Down
18 changes: 16 additions & 2 deletions python/sglang/srt/compilation/compilation_config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py

from typing import List
import json
from typing import List, Optional


# TODO(Yuwei): support better compile config support
class CompilationConfig:
def __init__(self, capture_sizes: List[int], compiler: str = "eager"):
splitting_ops: Optional[list[str]] = None

def __init__(
self,
capture_sizes: List[int] = [],
compiler: str = "eager",
splitting_ops: list[str] = [],
):
self.traced_files = set()
self.capture_sizes = capture_sizes
self.compiler = compiler
self.splitting_ops = splitting_ops

def add_traced_file(self, file_path: str):
self.traced_files.add(file_path)
Expand All @@ -18,3 +27,8 @@ def get_traced_files(self):

def get_capture_sizes(self):
return self.capture_sizes

@classmethod
def from_cli(cls, args) -> "CompilationConfig":
args_dict = json.loads(args)
return CompilationConfig(**args_dict)
52 changes: 52 additions & 0 deletions python/sglang/srt/compilation/custom_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Optional

import torch

import sglang.srt.layers.dp_attention


@torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=())
def _set_dp_buffer_len(
global_dp_buffer_len: Optional[int],
num_tokens: Optional[int],
is_max_len: bool,
global_num_tokens: Optional[List[int]] = None,
) -> None:
global set_dp_buffer_len_original
sglang.srt.layers.dp_attention.set_dp_buffer_len(
global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens
)


@_set_dp_buffer_len.register_fake
def _set_dp_buffer_len_fake(
global_dp_buffer_len: Optional[int],
num_tokens: Optional[int],
is_max_len: bool,
global_num_tokens: Optional[List[int]] = None,
) -> None:
pass


@torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=())
def _set_is_extend_in_batch(is_extend_in_batch: bool) -> None:
sglang.srt.layers.dp_attention.set_is_extend_in_batch(is_extend_in_batch)


@_set_is_extend_in_batch.register_fake
def _set_is_extend_in_batch_fake(is_extend_in_batch: bool) -> None:
pass
20 changes: 20 additions & 0 deletions python/sglang/srt/compilation/npu/compilation_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch_npu


class CompilationContext:
graph_memory_pool = None
stream: torch_npu.npu.Stream = None
29 changes: 29 additions & 0 deletions python/sglang/srt/compilation/npu/npu_graph_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch

from sglang.srt.compilation.npu.npu_graph_compiler_backend import (
NpuGraphCompilerBackend,
)


class NpuGraphCompiler:
def __init__(self, model: torch.nn.Module, model_type: torch.dtype):
torch._dynamo.reset()

self.backend = NpuGraphCompilerBackend(model_type)
self.compiled_callable = torch.compile(
model, fullgraph=True, dynamic=False, backend=self.backend
)
46 changes: 46 additions & 0 deletions python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable

import torch
from torch._dynamo.eval_frame import DisableContext

from sglang.srt.compilation.npu.pass_manager import PassManager
from sglang.srt.compilation.npu.passes.w8a8_int8 import (
DivFuse,
EraseCopy,
NpuAddRmsNormQuantFuse,
)


class NpuGraphCompilerBackend:
def __init__(self, model_type: torch.dtype):
self.model_type = model_type

def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable:
DisableContext.compiled_function_args[DisableContext.batch_size] = (
example_inputs
)
if self.model_type == torch.bfloat16:
NpuGraphCompilerBackend.apply_passes(graph)
return graph

def apply_passes(graph_module: torch.fx.GraphModule):
passManager = PassManager(graph_module)
passManager.add(NpuAddRmsNormQuantFuse)
passManager.add(DivFuse)
passManager.add(EraseCopy)
passManager.apply()
graph_module.recompile()
46 changes: 46 additions & 0 deletions python/sglang/srt/compilation/npu/pass_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch


class PassManager:
def __init__(self, graph_module: torch.fx.GraphModule):
self.graph_module = graph_module
self.passes = []

def add(self, pass_):
self.passes.append(pass_)

def apply(self):
updated = False
for pass_ in self.passes:
pass_instance = pass_()
results = []
try:
if callable(pass_instance):
results = pass_instance(self.graph_module)
else:
results = torch.fx.replace_pattern(
self.graph_module, pass_.pattern, pass_.replacement
)
except:
# pass was not applied
pass

if not updated:
updated = len(results) != 0

if updated:
self.graph_module.recompile()
100 changes: 100 additions & 0 deletions python/sglang/srt/compilation/npu/passes/w8a8_int8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch


class DivFuse:
def pattern(x):
y = 1.0 / x
z = 1.0 / y
return z

def replacement(x):
return x


class EraseCopy:
def __call__(self, graph_module: torch.fx.GraphModule):
copy_node = None
prepare_weight_cache_default_node = None

results = []
for module in graph_module.modules():
for node in list(module.graph.nodes):
if node.type == torch.nn.parameter.Parameter:
continue
if node.target == "copy_":
copy_node = node
prepare_weight_cache_default_node = None
continue

if (
copy_node
and node.target == torch.ops.sglang.prepare_weight_cache.default
):
prepare_weight_cache_default_node = node
continue

if copy_node and node.target == torch.ops.npu.npu_add_rms_norm_quant:
arg = copy_node.args[1]

if prepare_weight_cache_default_node is not None:
prepare_weight_cache_default_node.args = (
arg,
prepare_weight_cache_default_node.args[1],
)

node.args = (
node.args[0],
arg,
node.args[2],
node.args[3],
node.args[4],
)

module.graph.erase_node(copy_node)

result = (
arg,
copy_node,
prepare_weight_cache_default_node,
)
results.append(result)

copy_node = None
prepare_weight_cache_default_node = None

return results


class NpuAddRmsNormQuantFuse:
def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3):
output = torch.ops.npu.npu_add_rms_norm(
rms_norm_input, residual, rms_norm_weight, 1e-6
)
out0 = output[0]
out2 = output[2]
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3)
return quantized_output, out2

def replacement(
rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3
):
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6
)
quantized_output = output[0]
out2 = output[2]
return quantized_output, out2
Loading
Loading