NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize fuse. TorchAir compiler backend support. #11104
Open: eshoguli wants to merge 46 commits into sgl-project:main from eshoguli:eshogulin/pass_manager
Commits (46)
- 4ce70e6 NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize (eshoguli)
- b974460 pre-commit & refactoring (eshoguli)
- 7a7bde7 pre-commit (qyqc731)
- d8e2dc3 Merge branch 'main' into eshogulin/pass_manager (eshoguli)
- 9bb7751 Merge branch 'main' into eshogulin/pass_manager: fix - custom_ops.py (eshoguli)
- 7048005 cleanup & refactoring (eshoguli)
- eb240d9 Pass Manager fix (eshoguli)
- 29c1d89 Compilation: refactoring (eshoguli)
- 3e98d17 NPU Piecewise Graph (eshoguli)
- 3d9516a rollback (eshoguli)
- 2c1b6fe linter (eshoguli)
- 55016b0 refactoring (eshoguli)
- fbff08d refactoring (eshoguli)
- 3e5db77 Compilation: refactoring (eshoguli)
- 99d4497 Merge branch 'main' into eshogulin/pass_manager (eshoguli)
- 30da7fe model_type check (eshoguli)
- 36ef7e7 Merge branch 'main' into eshogulin/pass_manager (eshoguli)
- 1808479 PiecewiseNpuGraphCompilerBackend quick fix
- bcfc2c5 CompilationConfig reusage
- a6a159d --torch-compile-max-bs support
- c08d076 TorchAir compilation support (XDaoHong)
- 73f2ee9 runner selection fix: model forward usage (eshoguli)
- 2f97641 add test for torchair (XDaoHong)
- 7154cf4 TorchAir compilation support: refactoring (eshoguli)
- dfaee00 NPU Piecewise Graph: refactoring (eshoguli)
- bec1b28 Merge branch 'main' into eshogulin/pass_manager (eshoguli)
- 253c14d linter fix after merge commit (eshoguli)
- 85d808e NPUGraph compilation (fp16) & NPU Piecewise Graph tests (eshoguli)
- 11074d9 TorchAir compilation support: refactoring 2 (eshoguli)
- 51ac4b4 Merge branch 'main' into eshogulin/pass_manager (ping1jing2)
- e06675b CompilationConfig comments fix + linter fix (eshoguli)
- 0c09c24 backend instantiation in get_compiler_backend (eshoguli)
- 00a0b9b Merge branch 'main' into eshogulin/pass_manager (eshoguli)
- 0b31746 Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage… (eshoguli)
- 3b5c83b Merge branch 'main' into eshogulin/pass_manager (eshoguli)
- 7eefeee linter fix (eshoguli)
- 8c63980 dynamo patch removing (eshoguli)
- 2e02568 fix on main branch: compilation (eshoguli)
- 966bbf4 Merge branch 'main' into eshogulin/pass_manager (eshoguli)
- 14092b3 auto merge fix (eshoguli)
- f989147 tests suit update (eshoguli)
- bf1251d Add npu_add_rms_norm_dynamic_quant fuse (OrangeRedeng)
- 317174b Merge branch 'eshogulin/pass_manager' of https://github.com/eshoguli/… (OrangeRedeng)
- e6eb29c NPU Graph compilation: attention architecture check (eshoguli)
- caba95e Add npu_add_rms_norm_dynamic_quant fuse: quick fix (eshoguli)
- 3f87879 Qwen3 MoE compilation support for NPU (eshoguli)
New file (52 additions): custom-op wrappers for the `sglang.srt.layers.dp_attention` setters, so these side-effecting calls are visible to graph capture.

```python
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Optional

import torch

import sglang.srt.layers.dp_attention


@torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=())
def _set_dp_buffer_len(
    global_dp_buffer_len: Optional[int],
    num_tokens: Optional[int],
    is_max_len: bool,
    global_num_tokens: Optional[List[int]] = None,
) -> None:
    sglang.srt.layers.dp_attention.set_dp_buffer_len(
        global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens
    )


@_set_dp_buffer_len.register_fake
def _set_dp_buffer_len_fake(
    global_dp_buffer_len: Optional[int],
    num_tokens: Optional[int],
    is_max_len: bool,
    global_num_tokens: Optional[List[int]] = None,
) -> None:
    pass


@torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=())
def _set_is_extend_in_batch(is_extend_in_batch: bool) -> None:
    sglang.srt.layers.dp_attention.set_is_extend_in_batch(is_extend_in_batch)


@_set_is_extend_in_batch.register_fake
def _set_is_extend_in_batch_fake(is_extend_in_batch: bool) -> None:
    pass
```
New file (20 additions): a small shared context holder for NPU graph compilation.

```python
# Copyright 2023-2025 SGLang Team -- Apache License, Version 2.0 (full header as in the first file)

import torch_npu


class CompilationContext:
    graph_memory_pool = None
    stream: torch_npu.npu.Stream = None
```
New file (29 additions): the top-level compiler wrapper that wires the NPU backend into `torch.compile`.

```python
# Copyright 2025 SGLang Team -- Apache License, Version 2.0 (full header as in the first file)

import torch

from sglang.srt.compilation.npu.npu_graph_compiler_backend import (
    NpuGraphCompilerBackend,
)


class NpuGraphCompiler:
    def __init__(self, model: torch.nn.Module, model_type: torch.dtype):
        torch._dynamo.reset()

        self.backend = NpuGraphCompilerBackend(model_type)
        self.compiled_callable = torch.compile(
            model, fullgraph=True, dynamic=False, backend=self.backend
        )
```
python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py (46 additions):

```python
# Copyright 2025 SGLang Team -- Apache License, Version 2.0 (full header as in the first file)

from typing import Callable

import torch
from torch._dynamo.eval_frame import DisableContext

from sglang.srt.compilation.npu.pass_manager import PassManager
from sglang.srt.compilation.npu.passes.w8a8_int8 import (
    DivFuse,
    EraseCopy,
    NpuAddRmsNormQuantFuse,
)


class NpuGraphCompilerBackend:
    def __init__(self, model_type: torch.dtype):
        self.model_type = model_type

    def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable:
        DisableContext.compiled_function_args[DisableContext.batch_size] = (
            example_inputs
        )
        # Fusion passes are only applied for bfloat16 models.
        if self.model_type == torch.bfloat16:
            NpuGraphCompilerBackend.apply_passes(graph)
        return graph

    @staticmethod
    def apply_passes(graph_module: torch.fx.GraphModule):
        pass_manager = PassManager(graph_module)
        pass_manager.add(NpuAddRmsNormQuantFuse)
        pass_manager.add(DivFuse)
        pass_manager.add(EraseCopy)
        pass_manager.apply()
        graph_module.recompile()
```
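For readers new to `torch.compile` backends: a backend is simply a callable that receives a graph module and example inputs, optionally rewrites the graph, and returns the compiled callable. That protocol can be sketched without torch; the `FakeGraphModule` and `SketchBackend` classes below are illustrative stand-ins, not SGLang or PyTorch APIs:

```python
from typing import Any, Callable, List


class FakeGraphModule:
    """Stand-in for torch.fx.GraphModule: just wraps a function."""

    def __init__(self, fn: Callable):
        self.fn = fn

    def __call__(self, *args: Any) -> Any:
        return self.fn(*args)


class SketchBackend:
    """Mimics the shape of NpuGraphCompilerBackend: optionally rewrite,
    then return the (possibly modified) graph as the compiled callable."""

    def __init__(self, apply_passes: bool):
        self.apply_passes = apply_passes  # the real backend gates on model dtype
        self.passes_applied = False

    def __call__(self, graph: FakeGraphModule, example_inputs: List[Any]) -> Callable:
        if self.apply_passes:
            self.passes_applied = True  # a real backend would rewrite the graph here
        return graph


backend = SketchBackend(apply_passes=True)
compiled = backend(FakeGraphModule(lambda x: x * 2), [1])
print(compiled(21))  # 42
```

Returning the graph module itself, as both the sketch and the real backend do, means execution runs the (fused) FX graph eagerly rather than handing it to a further compiler stage.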
New file (46 additions; imported above as `sglang.srt.compilation.npu.pass_manager`):

```python
# Copyright 2025 SGLang Team -- Apache License, Version 2.0 (full header as in the first file)

import torch


class PassManager:
    def __init__(self, graph_module: torch.fx.GraphModule):
        self.graph_module = graph_module
        self.passes = []

    def add(self, pass_):
        self.passes.append(pass_)

    def apply(self):
        updated = False
        for pass_ in self.passes:
            pass_instance = pass_()
            results = []
            try:
                if callable(pass_instance):
                    # Imperative pass: it rewrites the graph and returns its matches.
                    results = pass_instance(self.graph_module)
                else:
                    # Declarative pass: pattern/replacement pair for subgraph rewriting.
                    results = torch.fx.replace_pattern(
                        self.graph_module, pass_.pattern, pass_.replacement
                    )
            except Exception:
                # pass was not applied
                pass

            if not updated:
                updated = len(results) != 0

        if updated:
            self.graph_module.recompile()
```
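The PassManager's control flow (collect passes, run each, recompile only if anything matched) can be sketched without torch. The toy "graph" below is just a list of op names and `ToyPassManager` is purely illustrative:

```python
class ToyPassManager:
    def __init__(self, graph):
        self.graph = graph  # toy graph: a mutable list of op names
        self.passes = []
        self.recompiled = False

    def add(self, pass_):
        self.passes.append(pass_)

    def apply(self):
        updated = False
        for pass_ in self.passes:
            results = pass_(self.graph)  # each pass returns the sites it rewrote
            if not updated:
                updated = len(results) != 0
        if updated:
            self.recompiled = True  # the real code calls graph_module.recompile()


def fuse_add_mul(graph):
    """Rewrite adjacent ['add', 'mul'] into one fused op; return match sites."""
    matches = []
    i = 0
    while i < len(graph) - 1:
        if graph[i] == "add" and graph[i + 1] == "mul":
            graph[i : i + 2] = ["fused_add_mul"]
            matches.append(i)
        i += 1
    return matches


pm = ToyPassManager(["add", "mul", "relu"])
pm.add(fuse_add_mul)
pm.apply()
print(pm.graph)  # ['fused_add_mul', 'relu']
```

Tracking whether any pass matched before recompiling, as both the sketch and the real manager do, avoids paying the recompile cost when the graph is untouched.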
New file (100 additions; imported above as `sglang.srt.compilation.npu.passes.w8a8_int8`): the fusion passes themselves.

```python
# Copyright 2025 SGLang Team -- Apache License, Version 2.0 (full header as in the first file)

import torch


class DivFuse:
    # Declarative pass: two successive reciprocals cancel out.
    def pattern(x):
        y = 1.0 / x
        z = 1.0 / y
        return z

    def replacement(x):
        return x


class EraseCopy:
    # Imperative pass: drops a redundant copy_ feeding npu_add_rms_norm_quant
    # and rewires its consumers to the copy source.
    def __call__(self, graph_module: torch.fx.GraphModule):
        copy_node = None
        prepare_weight_cache_default_node = None

        results = []
        for module in graph_module.modules():
            for node in list(module.graph.nodes):
                if node.type == torch.nn.parameter.Parameter:
                    continue
                if node.target == "copy_":
                    copy_node = node
                    prepare_weight_cache_default_node = None
                    continue

                if (
                    copy_node
                    and node.target == torch.ops.sglang.prepare_weight_cache.default
                ):
                    prepare_weight_cache_default_node = node
                    continue

                if copy_node and node.target == torch.ops.npu.npu_add_rms_norm_quant:
                    arg = copy_node.args[1]

                    if prepare_weight_cache_default_node is not None:
                        prepare_weight_cache_default_node.args = (
                            arg,
                            prepare_weight_cache_default_node.args[1],
                        )

                    node.args = (
                        node.args[0],
                        arg,
                        node.args[2],
                        node.args[3],
                        node.args[4],
                    )

                    module.graph.erase_node(copy_node)

                    result = (
                        arg,
                        copy_node,
                        prepare_weight_cache_default_node,
                    )
                    results.append(result)

                    copy_node = None
                    prepare_weight_cache_default_node = None

        return results


class NpuAddRmsNormQuantFuse:
    # Declarative pass: fuse npu_add_rms_norm + npu_quantize into the single
    # npu_add_rms_norm_quant op (note scale is inverted in the fused op).
    def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3):
        output = torch.ops.npu.npu_add_rms_norm(
            rms_norm_input, residual, rms_norm_weight, 1e-6
        )
        out0 = output[0]
        out2 = output[2]
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3)
        return quantized_output, out2

    def replacement(
        rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3
    ):
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input, residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6
        )
        quantized_output = output[0]
        out2 = output[2]
        return quantized_output, out2
```
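DivFuse relies on the algebraic identity 1/(1/x) == x. In floating point the round trip is only approximately exact, which is easy to check numerically in plain Python (no NPU needed):

```python
import math


def div_chain(x: float) -> float:
    # The shape DivFuse matches: two successive reciprocals.
    y = 1.0 / x
    return 1.0 / y


# The relative error of the round trip stays within a few ulps.
for x in [3.0, 0.1, 1e-8, 7.5e120]:
    assert math.isclose(div_chain(x), x, rel_tol=1e-12), x
print("1/(1/x) round-trips within 1e-12 relative tolerance for the samples above")
```

Because the fused replacement removes both divisions entirely, it is bit-exact rather than merely close, so the rewrite can only improve numerical behavior here.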