@eshoguli eshoguli commented Sep 30, 2025

Motivation

  1. Performance gain (Ascend 910B3, batch size = 128):

| Branch | Median ITL (ms) | Performance gain |
| --- | --- | --- |
| Reference | 74.27 | — |
| Compilation (`--enable-torch-compile`) | 71.64 | 3.5% |
| Piecewise Graph (`--enable-piecewise-npu-graph-decode`) | 73.20 | 1.4% |

  2. Support model compilation on NPU and a PassManager for current and future fusions, written in Python via torch.fx.replace_pattern. Fusions can easily be developed by external contributors.
  3. Improve performance by fusing the AddRmsNorm and AscendQuantV2 kernels into the AddRmsNormQuant kernel.
  4. Increase performance of the compiled model via NPU kernels and by avoiding torch guards.
  5. Piecewise graph execution approach.
  6. TorchAir compilation backend support.
    Original comment: [feat] npu support enable_torch_compile #12371
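The fusion mechanism above is built on torch.fx pattern rewriting. Since the NPU kernels (AddRmsNorm, AscendQuantV2, AddRmsNormQuant) are only available on Ascend devices, here is a minimal sketch of the same `torch.fx.replace_pattern` technique using a hypothetical stand-in fusion (`fused_add_relu`, `pattern`, `replacement`, and `TinyModel` are illustrative names, not from the PR):

```python
import torch
import torch.fx as fx


# Hypothetical stand-in for a fused NPU kernel (e.g. AddRmsNormQuant);
# fx.wrap keeps it a single call_function node during symbolic tracing.
@fx.wrap
def fused_add_relu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.relu(x + y)


def pattern(x, y):
    # Subgraph to search for: elementwise add followed by relu.
    return torch.relu(x + y)


def replacement(x, y):
    # Single fused node that replaces the matched subgraph.
    return fused_add_relu(x, y)


class TinyModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y) * 2


gm = fx.symbolic_trace(TinyModel())
matches = fx.replace_pattern(gm, pattern, replacement)
print(len(matches))  # expect 1 matched subgraph
```

An external contributor can add a new fusion by writing only a `pattern`/`replacement` pair; the graph rewrite and recompilation are handled by torch.fx.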

TorchAir (Torch Ascend Intermediate Representation) is an extension library that provides graph mode capabilities for torch_npu. It enables users to perform graph-mode inference on NPU using PyTorch and torch_npu. TorchAir externally offers a torch.compile backend for NPU, which interfaces with torch._dynamo. Through the following features, performance optimization and capability enhancement of the torch fx graph can be achieved.

[Image: torchair1]

TorchAir Main Features:

  1. Basic features:
  • Enable NPU kernels that depend on host-value tiling operators (e.g., FIA) to support npugraph
  • Graph input copy optimization
  • Memory reuse across multiple graphs
  2. FX passes:
  • In-place optimization
  • Redundant operator elimination
  • NPU fused operator passes
  3. Advanced features:
  • Static shape kernel compilation
  • Multi-stream within a single graph
  • Compilation caching

How to enable compilation and fusions for NPU Graph decode:

--enable-torch-compile

How to enable the piecewise graph and fusions for decode:

--enable-piecewise-npu-graph-decode

How to enable TorchAir for decode:

--enable-torch-compile --disable-cuda-graph
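Putting the three modes together, a server launch might look like the following sketch (the model path is a placeholder; the flags are the ones introduced by this PR plus the standard sglang launcher):

```shell
# Compilation + fusions for NPU Graph decode:
python3 -m sglang.launch_server \
    --model-path /path/to/Qwen3-32B \
    --enable-torch-compile

# Piecewise graph + fusions for decode:
python3 -m sglang.launch_server \
    --model-path /path/to/Qwen3-32B \
    --enable-piecewise-npu-graph-decode

# TorchAir for decode:
python3 -m sglang.launch_server \
    --model-path /path/to/Qwen3-32B \
    --enable-torch-compile --disable-cuda-graph
```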

CANN version: 8.2
Torch NPU version: torch-npu 2.6.0.post3

Modifications

  1. Model compilation support via torch.compile.
    Use --enable-torch-compile to enable compilation and the optional --torch-compile-max-bs argument to limit the maximum batch size for compilation.

  2. NpuGraphCompilerBackend: a compilation backend for NPU Graph capturing. Implemented in python/sglang/srt/model_executor/compilation/npu_graph_compiler_backend.py. Usage:

```python
self.compiled_callable = torch.compile(
    model, fullgraph=True, dynamic=False, backend=NpuGraphCompilerBackend()
)
```

  3. PiecewiseNpuGraphCompilerBackend: a compilation backend for piecewise graphs and partial NPU Graph capturing. Inherits from NpuGraphCompilerBackend to reuse the fusion passes. Implemented in python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py. Usage:

```python
self.compiled_callable = torch.compile(
    model, fullgraph=True, dynamic=False, backend=PiecewiseNpuGraphCompilerBackend()
)
```

Use --enable-piecewise-npu-graph-decode to enable the piecewise graph.
Optional command line arguments:

  • --compilation-config '{"splitting_ops": ["atb._npu_paged_attention"]}' to configure the compilation backend,
  • --cuda-graph-bs to specify the batch size,
  • --cuda-graph-max-bs to limit the maximum batch size.

  4. PassManager and its passes (python/sglang/srt/model_executor/compilation/passes/w8a8_int8) to optimize the model during compilation. Usage:

```python
import torch

from sglang.srt.compilation.npu.pass_manager import PassManager
from sglang.srt.compilation.npu.passes.w8a8_int8 import (
    DivFuse,
    EraseCopy,
    NpuAddRmsNormQuantFuse,
    NpuAddRmsNormDynamicQuantFuse,
)

def apply_passes(graph_module: torch.fx.GraphModule):
    pass_manager = PassManager(graph_module)
    pass_manager.add(NpuAddRmsNormQuantFuse)
    pass_manager.add(NpuAddRmsNormDynamicQuantFuse)
    pass_manager.add(DivFuse)
    pass_manager.add(EraseCopy)
    pass_manager.apply()
    graph_module.recompile()
```

  5. The RotaryEmbedding layer uses an NPU kernel in forward instead of the native implementation.
  6. torch.compile guards are ignored to improve forward performance.
  7. Ascend paged attention is used to enable compilation without custom ops: python/sglang/srt/layers/attention/ascend_backend.py.
  8. TorchAir:
    8.1. Rewrite the capture function;
    8.2. Encapsulate the kvcache input (the input needs the full kvcache);
    8.3. Pad the block table to the max length;
    8.4. Prepare TorchAir inputs.
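The PR shows how the PassManager is called but not its internals. A minimal, dependency-free sketch of the add/apply contract the usage above implies might look like this (the `Pass` base class, the toy list-of-ops "graph", and the simplified `EraseCopy` are hypothetical illustrations, not the real implementation):

```python
from typing import List


class Pass:
    """Base class for a rewrite pass over a graph-like object (hypothetical)."""

    def __init__(self, graph_module):
        self.graph_module = graph_module

    def run(self) -> None:
        raise NotImplementedError


class PassManager:
    """Minimal sketch: registers pass classes, instantiates and runs them in order."""

    def __init__(self, graph_module):
        self.graph_module = graph_module
        self._passes: List[type] = []

    def add(self, pass_cls: type) -> None:
        self._passes.append(pass_cls)

    def apply(self) -> None:
        for pass_cls in self._passes:
            pass_cls(self.graph_module).run()


# Toy "graph": a list of op names; a pass that erases redundant copy ops.
class EraseCopy(Pass):
    def run(self) -> None:
        self.graph_module[:] = [op for op in self.graph_module if op != "copy_"]


graph = ["add", "copy_", "rms_norm", "copy_", "quant"]
pm = PassManager(graph)
pm.add(EraseCopy)
pm.apply()
print(graph)  # -> ['add', 'rms_norm', 'quant']
```

Because passes are registered as classes and run in registration order, new fusions can be appended without modifying the manager itself.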

The calling process is as follows.
[Image: torchair2]

Class Diagram

```mermaid
classDiagram
    class PiecewiseNpuGraphRunnerDecode
    class NPUCompileModelRunner
    class NPUGraphRunner
    class CudaGraphRunner
    class NpuGraphCompiler
    class NpuGraphCompilerBackend
    class PiecewiseNpuGraphCompiler
    class PiecewiseNpuGraphCompilerBackend

    NPUGraphRunner --|> CudaGraphRunner
    NPUGraphRunner --> NpuGraphCompiler
    NpuGraphCompiler --> NpuGraphCompilerBackend
    NPUCompileModelRunner --> CudaGraphRunner
    PiecewiseNpuGraphRunnerDecode --> CudaGraphRunner
    PiecewiseNpuGraphRunnerDecode --> PiecewiseNpuGraphCompiler
    PiecewiseNpuGraphCompiler --> PiecewiseNpuGraphCompilerBackend
    PiecewiseNpuGraphCompilerBackend --|> NpuGraphCompilerBackend
```

Accuracy Tests

Collected on gsm8k dataset for static quantized Qwen3-32B:

| Version | Accuracy |
| --- | --- |
| Reference | 85.7% |
| Compilation | 85.6% |
| Piecewise Graph | 85.7% |
| TorchAir | 85.1% |

TorchAir

python3 few_shot_gsm8k.py --data-path "/path/to/model/test.jsonl.txt" --parallel 32 --num-questions 200

Accuracy: 0.865
Invalid: 0.000
Latency: 43.077 s
Output throughput: 795.877 token/s

Collected on MMMU dataset for Qwen3-VL-30B-A3B-Instruct:

| Version | Overall accuracy |
| --- | --- |
| Reference | 0.592 |
| Compilation | 0.597 |
| Piecewise Graph | 0.591 |

Benchmarking and Profiling (910B3)

Reference

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     128
Benchmark duration (s):                  119.11
Total input tokens:                      131072
Total generated tokens:                  131072
Total generated tokens (retokenized):    131061
Request throughput (req/s):              1.07
Input token throughput (tok/s):          1100.41
Output token throughput (tok/s):         1100.41
Total token throughput (tok/s):          2200.82
Concurrency:                             109.65
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   102033.93
Median E2E Latency (ms):                 100067.18
---------------Time to First Token----------------
Mean TTFT (ms):                          13474.46
Median TTFT (ms):                        13730.29
P99 TTFT (ms):                           24113.16
---------------Inter-Token Latency----------------
Mean ITL (ms):                           86.57
Median ITL (ms):                         74.27
P95 ITL (ms):                            79.96
P99 ITL (ms):                            80.59
Max ITL (ms):                            25360.72
==================================================

Compilation

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     128
Benchmark duration (s):                  117.06
Total input tokens:                      131072
Total input text tokens:                 131072
Total input vision tokens:               0
Total generated tokens:                  131072
Total generated tokens (retokenized):    131064
Request throughput (req/s):              1.09
Input token throughput (tok/s):          1119.68
Output token throughput (tok/s):         1119.68
Total token throughput (tok/s):          2239.35
Concurrency:                             108.96
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   99646.08
Median E2E Latency (ms):                 97652.90
---------------Time to First Token----------------
Mean TTFT (ms):                          13575.07
Median TTFT (ms):                        13454.43
P99 TTFT (ms):                           24318.40
---------------Inter-Token Latency----------------
Mean ITL (ms):                           84.14
Median ITL (ms):                         71.64
P95 ITL (ms):                            76.49
P99 ITL (ms):                            78.27
Max ITL (ms):                            24386.78
==================================================

Piecewise Graph

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     128
Benchmark duration (s):                  125.24
Total input tokens:                      131072
Total generated tokens:                  131072
Total generated tokens (retokenized):    131067
Request throughput (req/s):              1.02
Input token throughput (tok/s):          1046.58
Output token throughput (tok/s):         1046.58
Total token throughput (tok/s):          2093.17
Concurrency:                             103.59
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   101352.11
Median E2E Latency (ms):                 98694.90
---------------Time to First Token----------------
Mean TTFT (ms):                          13580.41
Median TTFT (ms):                        14449.29
P99 TTFT (ms):                           24292.08
---------------Inter-Token Latency----------------
Mean ITL (ms):                           85.80
Median ITL (ms):                         73.20
P95 ITL (ms):                            78.72
P99 ITL (ms):                            79.48
Max ITL (ms):                            25003.23
==================================================

Future roadmap

In torch_npu 7.2.0, the reduce-overhead mode of the torchair backend will support torch.compile(model, dynamic=True). This mode will be set as the default in get_compile_backend(), enabling support for methods wrapped by the @torch.compile() decorator.

In torch_npu 7.3.0, the capture and replay of NPUGraph, currently integrated in the torchair backend, will become optional. The torchair backend will then only perform optimizations such as FX pass optimization and static kernel compilation, while NPUGraph capture and replay will be implemented independently. This design is closer to the CudaGraphRunner implementation, decoupling FX graph optimization from graph offloading.
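The decorator usage the 7.2.0 roadmap targets can be sketched as follows. Since the torchair backend is NPU-only, `backend="eager"` is used here as a stand-in, and `scaled_residual` is a made-up example function:

```python
import torch


# Sketch of decorator-wrapped compilation with dynamic shapes; backend="eager"
# stands in for the torchair reduce-overhead backend, which requires an NPU.
@torch.compile(dynamic=True, backend="eager")
def scaled_residual(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    return (x + residual) * 0.5


out = scaled_residual(torch.ones(4), torch.ones(4))
print(out)  # tensor of ones: (1 + 1) * 0.5
```

With dynamic=True, the same compiled artifact serves varying input shapes, which is what makes a decorator on a method with changing batch sizes practical.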

Checklist

@eshoguli eshoguli changed the title [WIP] NPU Graph Compilation & PassManager NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize fuse Oct 30, 2025
@eshoguli eshoguli force-pushed the eshogulin/pass_manager branch 9 times, most recently from 508e483 to d77e709 Compare October 30, 2025 22:40
@eshoguli eshoguli force-pushed the eshogulin/pass_manager branch from c958827 to b974460 Compare October 31, 2025 15:38
@eshoguli eshoguli force-pushed the eshogulin/pass_manager branch from 8150d72 to 11074d9 Compare November 20, 2025 08:07
@ssshinigami (Contributor) left a comment:

LGTM

@ping1jing2 ping1jing2 self-assigned this Nov 20, 2025
@eshoguli eshoguli force-pushed the eshogulin/pass_manager branch from e6942bc to e06675b Compare November 21, 2025 09:18
@yuan-luo yuan-luo self-requested a review November 27, 2025 02:36
@eshoguli eshoguli requested a review from hebiao064 as a code owner November 27, 2025 13:39