Skip to content

Commit 2baee2a

Browse files
author
luoyuan.luo
committed
Revise compilation and pass_manager
1 parent 5450a94 commit 2baee2a

File tree

6 files changed

+48
-62
lines changed

6 files changed

+48
-62
lines changed

python/sglang/srt/compilation/backend.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,21 +89,24 @@ def load(
8989
graph_index: int,
9090
runtime_shape: Optional[int] = None,
9191
) -> Optional[Callable]:
92-
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
92+
key = (runtime_shape, graph_index, self.compiler.name)
93+
handle = self.cache.get(key, None)
94+
if handle is None:
95+
return None
96+
9397
compiled_graph = self.compiler.load(
9498
handle, graph, example_inputs, graph_index, runtime_shape
9599
)
96100
if runtime_shape is None:
97101
logger.debug(
98-
"Directly load the %s-th graph for dynamic shape from %s via "
99-
"handle %s",
102+
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
100103
graph_index,
101104
self.compiler.name,
102105
handle,
103106
)
104107
else:
105108
logger.debug(
106-
"Directly load the %s-th graph for shape %s from %s via " "handle %s",
109+
"Directly load the %s-th graph for shape %s from %s via handle %s",
107110
graph_index,
108111
str(runtime_shape),
109112
self.compiler.name,
@@ -299,7 +302,7 @@ def __init__(
299302
# When True, it annoyingly dumps the torch.fx.Graph on errors.
300303
self.extra_traceback = False
301304
self.sglang_config = sglang_config
302-
self.compilation_config = sglang_config.compile_config
305+
self.compilation_config = sglang_config.compilation_config
303306

304307
def run(self, *args):
305308
fake_args = [
@@ -329,7 +332,7 @@ def call_module(
329332
self.sglang_backend.compiler_manager.compile(
330333
submod,
331334
args,
332-
self.inductor_config,
335+
self.sglang_backend.inductor_config,
333336
self.compilation_config,
334337
graph_index=index,
335338
num_graphs=len(self.compile_submod_names),
@@ -340,7 +343,7 @@ def call_module(
340343
self.module.__dict__[target] = CUDAPiecewiseBackend(
341344
submod,
342345
self.compilation_config,
343-
self.inductor_config,
346+
self.sglang_backend.inductor_config,
344347
self.graph_pool,
345348
index,
346349
len(self.compile_submod_names),

python/sglang/srt/compilation/fix_functionalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch._higher_order_ops.auto_functionalize import auto_functionalized
1010

1111
from sglang.srt.compilation.fx_utils import is_func
12-
from sglang.srt.compilation.inductor_pass import SGLangInductorPass
12+
from sglang.srt.compilation.sglang_inductor_pass import SGLangInductorPass
1313

1414
logger = logging.getLogger(__name__)
1515

python/sglang/srt/compilation/inductor_pass.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@
55
import inspect
66
import json
77
import logging
8-
import time
98
import types
109
from contextlib import contextmanager
1110
from typing import Any, Callable, Optional, Union
1211

1312
import torch
1413
from torch import fx
15-
from torch._dynamo.utils import lazy_format_graph_code
1614
from torch._inductor.custom_graph_pass import CustomGraphPass
1715
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
1816

@@ -113,35 +111,6 @@ def uuid(self) -> Any:
113111
return self._uuid
114112

115113

116-
class SGLangInductorPass(InductorPass):
117-
118-
def __init__(
119-
self,
120-
):
121-
self.pass_name = self.__class__.__name__
122-
123-
def dump_graph(self, graph: torch.fx.Graph, stage: str):
124-
lazy_format_graph_code(stage, graph.owning_module)
125-
126-
def begin(self):
127-
self._start_time = time.perf_counter_ns()
128-
129-
def end_and_log(self):
130-
self._end_time = time.perf_counter_ns()
131-
duration_ms = float(self._end_time - self._start_time) / 1.0e6
132-
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
133-
134-
135-
class PrinterInductorPass(SGLangInductorPass):
136-
137-
def __init__(self, name: str):
138-
super().__init__()
139-
self.name = name
140-
141-
def __call__(self, graph: torch.fx.Graph):
142-
self.dump_graph(graph, self.name)
143-
144-
145114
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
146115
"""
147116
Applies a FakeTensorMode context. This is useful when you don't want to

python/sglang/srt/compilation/pass_manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from sglang.srt.compilation.inductor_pass import (
1010
CustomGraphPass,
1111
InductorPass,
12-
SGLangInductorPass,
1312
get_pass_context,
1413
)
14+
from sglang.srt.compilation.sglang_inductor_pass import SGLangInductorPass
1515
from sglang.srt.configs.sglang_config import SGLangConfig, set_current_sglang_config
1616

1717
logger = logging.getLogger(__name__)
@@ -45,12 +45,12 @@ def __call__(self, graph: fx.Graph):
4545
self.fix_functionalization(graph)
4646

4747
def configure(self, config: SGLangConfig):
48-
# TODO(yuan-luo): PassConfig
49-
self.pass_config = dict()
50-
self.fix_functionalization = FixFunctionalizationPass()
48+
self.pass_config = config.compilation_config.pass_config
5149

5250
with set_current_sglang_config(config, check_compile=False):
53-
self.passes += [AllReduceFusionPass(config)]
51+
if self.pass_config.enable_fi_allreduce_fusion:
52+
self.passes += [AllReduceFusionPass(config)]
53+
self.fix_functionalization = FixFunctionalizationPass(config)
5454

5555
def add(self, pass_: InductorPass):
5656
assert isinstance(pass_, InductorPass)

python/sglang/srt/configs/compilation_config.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class CompilationMode:
2525
shape specialization, and custom passes."""
2626

2727

28+
@dataclass
2829
class PassConfig:
2930
"""Configuration for custom Inductor passes.
3031
This is separate from general `CompilationConfig` so that inductor passes
@@ -69,9 +70,20 @@ class CompilationConfig:
6970
certain small batchsizes, where inductor is good at optimizing.
7071
"""
7172

73+
# Sizes to capture cudagraph.
74+
# - None (default): capture sizes are inferred from sglang config.
75+
# - list[int]: capture sizes are specified as given.
76+
capture_sizes: List[int]
77+
78+
compiler: str = "eager"
79+
80+
enable_debug_mode: bool = False
81+
7282
# Top-level Compilation control
7383
level: Optional[int] = None
7484

85+
mode: CompilationMode | None = None
86+
7587
# The backend for compilation. It needs to be a string:
7688
# (empty string): use the default backend ("inductor" on CUDA-alike
7789
# platforms).
@@ -82,32 +94,32 @@ class CompilationConfig:
8294
# Inductor capture
8395
use_inductor: bool = True
8496

97+
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
98+
to enable all, 'none' to disable all. Also specify a list of custom op
99+
names to enable (prefixed with a '+'), or disable (prefixed with a '-').
100+
Examples:
101+
102+
- 'all,-op1' to enable all except op1
103+
- 'none,+op1,+op2' to enable only op1 and op2
104+
105+
By default, all custom ops are enabled when running without Inductor and
106+
disabled when running with Inductor: mode>=SGLANG_COMPILE and backend="inductor".
107+
Inductor generates (fused) Triton kernels for disabled custom ops."""
108+
splitting_ops: list[str] | None = None
109+
110+
use_inductor_graph_partition: bool = False
111+
85112
inductor_compile_config: dict = field(default_factory=dict)
86113

87114
inductor_passes: dict[str, str] = field(default_factory=dict)
88115

89-
# Sizes to capture cudagraph.
90-
# - None (default): capture sizes are inferred from sglang config.
91-
# - list[int]: capture sizes are specified as given.
92-
cudagraph_capture_sizes: list[int] | None = None
93-
94116
pass_config: PassConfig = field(default_factory=PassConfig)
95117

96118
# time taken for compilation
97119
compilation_time: float = field(default=0.0, init=False)
98120

99-
compiler: str = ""
100-
101-
def __init__(
102-
self,
103-
capture_sizes: List[int],
104-
compiler: str = "eager",
105-
enable_debug_mode: bool = False,
106-
):
107-
self.traced_files = set()
108-
self.capture_sizes = capture_sizes
109-
self.compiler = compiler
110-
self.enable_debug_mode = enable_debug_mode
111-
112121
def get_capture_sizes(self):
113122
return self.capture_sizes
123+
124+
def get_enable_debug_mode(self):
125+
return self.enable_debug_mode

python/sglang/srt/configs/sglang_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import logging
3+
from contextlib import contextmanager
34
from dataclasses import replace
45
from functools import lru_cache
56

@@ -47,6 +48,7 @@ def with_hf_config(
4748
_current_prefix: str | None = None
4849

4950

51+
@contextmanager
5052
def set_current_sglang_config(
5153
sglang_config: SGLangConfig, check_compile=False, prefix: str | None = None
5254
):

0 commit comments

Comments
 (0)