
Commit 5450a94

luoyuan.luo and yuan-luo committed
Support fusion pass and refactor pass manager
Co-authored-by: Yuan Luo <[email protected]>
1 parent 45cf575 commit 5450a94

14 files changed: +1028 −56 lines

python/sglang/srt/compilation/backend.py

Lines changed: 79 additions & 18 deletions
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backends.py
 
 
 import ast
@@ -15,11 +15,13 @@
 import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher
 
-from sglang.srt.compilation.compilation_config import CompilationConfig
 from sglang.srt.compilation.compilation_counter import compilation_counter
 from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
 from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
+from sglang.srt.compilation.inductor_pass import InductorPass
 from sglang.srt.compilation.pass_manager import PostGradPassManager
+from sglang.srt.configs.compilation_config import CompilationConfig
+from sglang.srt.configs.sglang_config import SGLangConfig
 from sglang.srt.utils.common import rank0_log
 
 logger = logging.getLogger(__name__)
@@ -114,6 +116,7 @@ def compile(
         graph: fx.GraphModule,
         example_inputs,
         inductor_config: dict[str, Any],
+        compilation_config: CompilationConfig,
         graph_index: int = 0,
         num_graphs: int = 1,
         runtime_shape: Optional[int] = None,
@@ -127,7 +130,29 @@ def compile(
 
         compiled_graph = None
 
-        # TODO(Yuwei): support cache loading
+        # try to load from the cache
+        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
+        if compiled_graph is not None:
+            if graph_index == num_graphs - 1:
+                # after loading the last graph for this shape, record the time.
+                # there can be multiple graphs due to piecewise compilation.
+                now = time.time()
+                elapsed = now - compilation_start_time
+                compilation_config.compilation_time += elapsed
+                if runtime_shape is None:
+                    logger.info(
+                        "Directly load the compiled graph(s) for dynamic shape "
+                        "from the cache, took %.3f s",
+                        elapsed,
+                    )
+                else:
+                    logger.info(
+                        "Directly load the compiled graph(s) for shape %s "
+                        "from the cache, took %.3f s",
+                        str(runtime_shape),
+                        elapsed,
+                    )
+            return compiled_graph
 
         # no compiler cached the graph, or the cache is disabled,
         # we need to compile it
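
For context on the new cache path: compile now asks self.load for a previously compiled artifact and only falls through to actual compilation on a miss, charging the load time to compilation_config.compilation_time. Below is a minimal sketch of the shape of such a cache, assuming a plain in-memory dict keyed by (graph_index, runtime_shape); the real CompilerManager load path goes through the compiler adapters and their on-disk caches.

from typing import Callable, Optional

# Hypothetical in-memory cache, keyed the way the hunk above indexes
# compiled graphs: one entry per piecewise submodule (graph_index) and
# per compiled shape (runtime_shape; None means the dynamic-shape graph).
class TinyCompilerCache:
    def __init__(self) -> None:
        self._cache: dict[tuple[int, Optional[int]], Callable] = {}

    def load(self, graph_index: int, runtime_shape: Optional[int]) -> Optional[Callable]:
        # A miss returns None, and the caller falls through to compilation.
        return self._cache.get((graph_index, runtime_shape))

    def save(self, graph_index: int, runtime_shape: Optional[int], fn: Callable) -> None:
        self._cache[(graph_index, runtime_shape)] = fn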
@@ -174,6 +199,7 @@ def compile(
         if graph_index == num_graphs - 1:
             now = time.time()
             elapsed = now - compilation_start_time
+            compilation_config.compilation_time += elapsed
             if runtime_shape is None:
                 logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
             else:
@@ -240,20 +266,27 @@ def split_graph(
     return split_gm, outputs
 
 
-# we share the global graph pool among all the backends
-global_graph_pool = None
-
 compilation_start_time = 0.0
 
 
 class PiecewiseCompileInterpreter(torch.fx.Interpreter):
+    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
+    It runs the given graph with fake inputs, and compile some
+    submodules specified by `compile_submod_names` with the given
+    compilation configs.
+
+    NOTE: the order in `compile_submod_names` matters, because
+    it will be used to determine the order of the compiled piecewise
+    graphs. The first graph will handle logging, and the last graph
+    has some special cudagraph output handling.
+    """
+
     def __init__(
         self,
         module: torch.fx.GraphModule,
         compile_submod_names: list[str],
-        inductor_config: dict[str, Any],
         graph_pool,
-        compile_config: CompilationConfig,
+        sglang_config: SGLangConfig,
         sglang_backend: "SGLangBackend",
     ):
         super().__init__(module)
@@ -265,8 +298,8 @@ def __init__(
         self.sglang_backend = sglang_backend
         # When True, it annoyingly dumps the torch.fx.Graph on errors.
         self.extra_traceback = False
-        self.inductor_config = inductor_config
-        self.compile_config = compile_config
+        self.sglang_config = sglang_config
+        self.compilation_config = sglang_config.compile_config
 
     def run(self, *args):
         fake_args = [
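
The interpreter pattern above is easy to miss in diff form: the class runs the split graph once (with fake tensors, per run and enable_python_dispatcher) and compiles submodules as it encounters them. A toy version of that mechanism, using torch.compile in place of the SGLang compiler stack and real tensors instead of fake ones, might look like this (ToyCompileInterpreter and everything in it is illustrative, not the SGLang implementation):

import torch
import torch.fx as fx

class ToyCompileInterpreter(fx.Interpreter):
    def __init__(self, module: fx.GraphModule, compile_submod_names: list[str]):
        super().__init__(module)
        self.compile_submod_names = compile_submod_names

    def call_module(self, target, args, kwargs):
        output = super().call_module(target, args, kwargs)
        if target in self.compile_submod_names:
            submod = self.fetch_attr(target)
            # Swap the eager submodule for a compiled one; later calls
            # through the split graph now hit the compiled version.
            self.module.__dict__[target] = torch.compile(submod)
        return output

# Usage with a split graph (see the split_module sketch further down):
# ToyCompileInterpreter(split_gm, ["submod_0"]).run(example_input)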
@@ -297,6 +330,7 @@ def call_module(
                 submod,
                 args,
                 self.inductor_config,
+                self.compilation_config,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
@@ -305,7 +339,7 @@ def call_module(
 
             self.module.__dict__[target] = CUDAPiecewiseBackend(
                 submod,
-                self.compile_config,
+                self.compilation_config,
                 self.inductor_config,
                 self.graph_pool,
                 index,
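
CUDAPiecewiseBackend itself is defined elsewhere, but the wiring above is easier to follow with a picture of what a piecewise backend does at call time. A rough sketch, assumed rather than taken from the real class: wrap one compiled submodule, keep a compiled entry per static batch size, and fall back to the dynamic-shape graph for anything unseen.

from typing import Callable, Optional

import torch

class ToyPiecewiseBackend:
    def __init__(self, general_graph: Callable):
        self.general_graph = general_graph        # compiled for dynamic shape
        self.per_shape: dict[int, Callable] = {}  # compiled per static shape

    def __call__(self, *args: torch.Tensor):
        # Dispatch on the leading (batch) dimension of the first input.
        batch: Optional[int] = args[0].shape[0] if args else None
        fn = self.per_shape.get(batch, self.general_graph)
        return fn(*args)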
@@ -339,7 +373,19 @@ def set_model_tag(tag: str):
 
 
 class SGLangBackend:
+    """The compilation backend for `torch.compile` with SGLang.
+    It is used for compilation mode of `CompilationMode.SGLANG_COMPILE`,
+    where we customize the compilation.
+
+    The major work of this backend is to split the graph into
+    piecewise graphs, and pass them to the piecewise backend.
+
+    This backend also adds the PostGradPassManager to Inductor config,
+    which handles the post-grad passes.
+    """
 
+    sglang_config: SGLangConfig
+    compilation_config: CompilationConfig
     graph_pool: Any
     _called: bool = False
     # the graph we compiled
356402

357403
def __init__(
358404
self,
359-
config: CompilationConfig,
405+
sglang_config: SGLangConfig,
360406
graph_pool: Any,
361407
):
362408
rank0_log(f"Initializing SGLangBackend")
@@ -367,15 +413,31 @@ def __init__(
         self.sym_tensor_indices = []
         self.input_buffers = []
 
-        self.compiler_manager = CompilerManager(config)
+        self.sglang_config = sglang_config
+        self.compilation_config = sglang_config.compilation_config
+
+        self.compiler_manager = CompilerManager(self.compilation_config)
         self.inductor_config = {
             "enable_auto_functionalized_v2": False,
         }
-        self.compile_config = config
 
     def configure_post_pass(self):
-        self.post_grad_pass_manager.configure()
-        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
+        config = self.compilation_config
+        self.post_grad_pass_manager.configure(self.sglang_config)
+
+        # Post-grad custom passes are run using the post_grad_custom_post_pass
+        # hook. If a pass for that hook exists, add it to the pass manager.
+        inductor_config = config.inductor_compile_config
+        PASS_KEY = "post_grad_custom_post_pass"
+        if PASS_KEY in inductor_config:
+            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+                # PassManager already added to config
+                pass
+            else:
+                # Config should automatically wrap all inductor passes
+                assert isinstance(inductor_config[PASS_KEY], InductorPass)
+                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
+        inductor_config[PASS_KEY] = self.post_grad_pass_manager
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         rank0_log(f"SGLangBackend __call__")
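
The refactored configure_post_pass maintains one invariant: whatever ends up under Inductor's post_grad_custom_post_pass hook is the PostGradPassManager itself, with any pre-existing user pass absorbed as a sub-pass rather than overwritten. A hedged sketch of that contract (ToyPostGradPassManager is a stand-in; only the hook name and the callable-taking-an-fx.Graph interface come from Inductor):

import torch.fx as fx

class ToyPostGradPassManager:
    def __init__(self) -> None:
        self.passes: list = []

    def add(self, inductor_pass) -> None:
        self.passes.append(inductor_pass)

    def __call__(self, graph: fx.Graph) -> None:
        # Inductor invokes the hook once with the post-grad FX graph.
        for p in self.passes:
            p(graph)

inductor_config: dict = {}
PASS_KEY = "post_grad_custom_post_pass"
manager = ToyPostGradPassManager()
if PASS_KEY in inductor_config and not isinstance(inductor_config[PASS_KEY], ToyPostGradPassManager):
    manager.add(inductor_config[PASS_KEY])  # absorb the user's existing pass
inductor_config[PASS_KEY] = manager         # the manager owns the hook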
@@ -427,9 +489,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         PiecewiseCompileInterpreter(
             self.split_gm,
             submod_names_to_compile,
-            self.inductor_config,
             self.graph_pool,
-            self.compile_config,
+            self.sglang_config,
             self,
         ).run(*example_inputs)
 
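Stepping back, SGLangBackend.__call__ exists because torch.compile backends follow a simple protocol: a callable that receives an fx.GraphModule plus example inputs and returns a callable. A self-contained demonstration with a trivial stand-in backend (an eager fallback in the spirit of EagerAdapter, not the SGLang backend itself):

import torch

def toy_backend(gm: torch.fx.GraphModule, example_inputs):
    # torch.compile hands us the captured graph; returning gm.forward
    # simply runs it eagerly, like an eager-adapter compiler would.
    print(f"got a graph with {len(gm.graph.nodes)} nodes")
    return gm.forward

@torch.compile(backend=toy_backend)
def f(x):
    return torch.sin(x) + 1

f(torch.randn(4))  # prints the node count on first call, then runs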