@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backends.py
 
 
 import ast
@@ -15,11 +15,13 @@
 import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher
 
-from sglang.srt.compilation.compilation_config import CompilationConfig
 from sglang.srt.compilation.compilation_counter import compilation_counter
 from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
 from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
+from sglang.srt.compilation.inductor_pass import InductorPass
 from sglang.srt.compilation.pass_manager import PostGradPassManager
+from sglang.srt.configs.compilation_config import CompilationConfig
+from sglang.srt.configs.sglang_config import SGLangConfig
 from sglang.srt.utils.common import rank0_log
 
 logger = logging.getLogger(__name__)
@@ -114,6 +116,7 @@ def compile(
         graph: fx.GraphModule,
         example_inputs,
         inductor_config: dict[str, Any],
+        compilation_config: CompilationConfig,
         graph_index: int = 0,
         num_graphs: int = 1,
         runtime_shape: Optional[int] = None,
@@ -127,7 +130,29 @@ def compile(
 
         compiled_graph = None
 
-        # TODO(Yuwei): support cache loading
+        # try to load from the cache
+        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
+        if compiled_graph is not None:
+            if graph_index == num_graphs - 1:
+                # after loading the last graph for this shape, record the time.
+                # there can be multiple graphs due to piecewise compilation.
+                now = time.time()
+                elapsed = now - compilation_start_time
+                compilation_config.compilation_time += elapsed
+                if runtime_shape is None:
+                    logger.info(
+                        "Directly load the compiled graph(s) for dynamic shape "
+                        "from the cache, took %.3f s",
+                        elapsed,
+                    )
+                else:
+                    logger.info(
+                        "Directly load the compiled graph(s) for shape %s "
+                        "from the cache, took %.3f s",
+                        str(runtime_shape),
+                        elapsed,
+                    )
+            return compiled_graph
 
         # no compiler cached the graph, or the cache is disabled,
         # we need to compile it
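The hunk above only charges the elapsed time to `compilation_config.compilation_time` when the *last* piecewise graph for a shape is loaded. A toy, stand-alone illustration of that accounting (the numbers and loop are hypothetical, only the "record once, on the last piece" pattern is taken from the diff):

```python
import time

# With piecewise compilation, one runtime shape yields several submodule graphs;
# the wall-clock time since compilation_start_time is added to the running total
# exactly once, when the last piece for that shape finishes loading/compiling.
compilation_time = 0.0
num_graphs = 3                      # hypothetical number of pieces for one shape
start = time.time()                 # analogous to compilation_start_time
for graph_index in range(num_graphs):
    pass                            # compile or load piece `graph_index` here
    if graph_index == num_graphs - 1:
        compilation_time += time.time() - start
```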
@@ -174,6 +199,7 @@ def compile(
         if graph_index == num_graphs - 1:
             now = time.time()
             elapsed = now - compilation_start_time
+            compilation_config.compilation_time += elapsed
             if runtime_shape is None:
                 logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
             else:
@@ -240,20 +266,27 @@ def split_graph(
     return split_gm, outputs
 
 
-# we share the global graph pool among all the backends
-global_graph_pool = None
-
 compilation_start_time = 0.0
 
 
 class PiecewiseCompileInterpreter(torch.fx.Interpreter):
273+ """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
274+ It runs the given graph with fake inputs, and compile some
275+ submodules specified by `compile_submod_names` with the given
276+ compilation configs.
277+
278+ NOTE: the order in `compile_submod_names` matters, because
279+ it will be used to determine the order of the compiled piecewise
280+ graphs. The first graph will handle logging, and the last graph
281+ has some special cudagraph output handling.
282+ """
283+
     def __init__(
         self,
         module: torch.fx.GraphModule,
         compile_submod_names: list[str],
-        inductor_config: dict[str, Any],
         graph_pool,
-        compile_config: CompilationConfig,
+        sglang_config: SGLangConfig,
         sglang_backend: "SGLangBackend",
     ):
         super().__init__(module)
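The class docstring above describes re-running the split graph under fake tensors so each submodule can be compiled without launching real kernels. A minimal, stand-alone sketch of that idea in plain PyTorch (not SGLang code; `f` and the tensor sizes are made up for illustration):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.relu(x) + 1

gm = make_fx(f)(torch.randn(4))           # an fx.GraphModule to interpret

mode = FakeTensorMode()
fake_in = mode.from_tensor(torch.randn(4))
with mode:
    # Interpreter.run re-executes the graph node by node; with fake tensors
    # only shapes/dtypes are propagated, no real computation happens.
    out = torch.fx.Interpreter(gm).run(fake_in)
print(out.shape)                          # torch.Size([4])
```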
@@ -265,8 +298,8 @@ def __init__(
         self.sglang_backend = sglang_backend
         # When True, it annoyingly dumps the torch.fx.Graph on errors.
         self.extra_traceback = False
-        self.inductor_config = inductor_config
-        self.compile_config = compile_config
+        self.sglang_config = sglang_config
+        self.compilation_config = sglang_config.compilation_config
 
     def run(self, *args):
         fake_args = [
@@ -297,6 +330,7 @@ def call_module(
                 submod,
                 args,
                 self.inductor_config,
+                self.compilation_config,
                 graph_index=index,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None,
@@ -305,7 +339,7 @@ def call_module(
 
             self.module.__dict__[target] = CUDAPiecewiseBackend(
                 submod,
-                self.compile_config,
+                self.compilation_config,
                 self.inductor_config,
                 self.graph_pool,
                 index,
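`call_module` above replaces each compiled submodule's entry in the parent module's `__dict__` with a `CUDAPiecewiseBackend` wrapper. The same swap pattern, reduced to a generic stand-alone example (the `CountingWrapper` name and behavior are made up; only the `__dict__[target] = wrapper` trick mirrors the diff):

```python
import torch

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.inner(x)

class CountingWrapper:
    """Stand-in for CUDAPiecewiseBackend: any callable with the same signature."""
    def __init__(self, inner):
        self.inner = inner
        self.calls = 0

    def __call__(self, *args):
        self.calls += 1
        return self.inner(*args)

m = Outer()
wrapper = CountingWrapper(m.inner)
# Writing to __dict__ bypasses nn.Module.__setattr__ and shadows the _modules
# entry, so forward() now dispatches to the wrapper, exactly as call_module does.
m.__dict__["inner"] = wrapper
m(torch.randn(2, 4))
print(wrapper.calls)   # 1
```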
@@ -339,7 +373,19 @@ def set_model_tag(tag: str):
 
 
 class SGLangBackend:
376+ """The compilation backend for `torch.compile` with SGLang.
377+ It is used for compilation mode of `CompilationMode.SGLANG_COMPILE`,
378+ where we customize the compilation.
379+
380+ The major work of this backend is to split the graph into
381+ piecewise graphs, and pass them to the piecewise backend.
382+
383+ This backend also adds the PostGradPassManager to Inductor config,
384+ which handles the post-grad passes.
385+ """
 
+    sglang_config: SGLangConfig
+    compilation_config: CompilationConfig
     graph_pool: Any
     _called: bool = False
     # the graph we compiled
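The docstring above says the backend's main job is splitting the traced graph into piecewise graphs (the `split_graph` helper whose `return split_gm, outputs` appears earlier in this diff). A conceptual sketch of that kind of split using `torch.fx.passes.split_module`; this is *not* SGLang's `split_graph`, and `is_splitting_op` is an assumed predicate (e.g. "is this an attention custom op"):

```python
from torch.fx.passes.split_module import split_module

def split_by_op(gm, is_splitting_op):
    """Give every 'splitting op' its own partition so the remaining pieces
    stay cudagraph-friendly."""
    counter = [0]

    def partition_of(node):
        if node.op == "call_function" and is_splitting_op(node.target):
            counter[0] += 1          # the splitting op gets a fresh partition
            this_partition = counter[0]
            counter[0] += 1          # nodes after it start yet another one
            return this_partition
        return counter[0]

    # split_module groups nodes by the returned partition id into submod_* children.
    return split_module(gm, None, partition_of, keep_original_order=True)
```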
@@ -356,7 +402,7 @@ class SGLangBackend:
 
     def __init__(
         self,
-        config: CompilationConfig,
+        sglang_config: SGLangConfig,
         graph_pool: Any,
     ):
         rank0_log(f"Initializing SGLangBackend")
@@ -367,15 +413,31 @@ def __init__(
         self.sym_tensor_indices = []
         self.input_buffers = []
 
-        self.compiler_manager = CompilerManager(config)
+        self.sglang_config = sglang_config
+        self.compilation_config = sglang_config.compilation_config
+
+        self.compiler_manager = CompilerManager(self.compilation_config)
         self.inductor_config = {
             "enable_auto_functionalized_v2": False,
         }
-        self.compile_config = config
 
     def configure_post_pass(self):
-        self.post_grad_pass_manager.configure()
-        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
+        config = self.compilation_config
+        self.post_grad_pass_manager.configure(self.sglang_config)
+
+        # Post-grad custom passes are run using the post_grad_custom_post_pass
+        # hook. If a pass for that hook exists, add it to the pass manager.
+        inductor_config = config.inductor_compile_config
+        PASS_KEY = "post_grad_custom_post_pass"
+        if PASS_KEY in inductor_config:
+            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+                # PassManager already added to config
+                pass
+            else:
+                # Config should automatically wrap all inductor passes
+                assert isinstance(inductor_config[PASS_KEY], InductorPass)
+                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
+        inductor_config[PASS_KEY] = self.post_grad_pass_manager
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         rank0_log(f"SGLangBackend __call__")
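`configure_post_pass` above funnels every custom pass through Inductor's single `post_grad_custom_post_pass` config key. Outside of SGLang, that hook is just a callable that receives the post-grad FX graph; a minimal sketch of what such a pass looks like (the pass itself is a made-up example, and it is not handed to the compiler here):

```python
import torch

# A post-grad custom pass is a callable that takes the post-grad FX graph and
# may rewrite it in place. The pass manager registered above bundles all such
# passes behind one callable (and, on recent torch versions, also provides the
# uuid() that Inductor's FX-graph cache expects), which is why everything ends
# up under this single config key.
def log_graph_size(graph: torch.fx.Graph) -> None:
    print(f"post-grad graph has {len(list(graph.nodes))} nodes")

inductor_config = {"post_grad_custom_post_pass": log_graph_size}
# `inductor_config` would then be handed to the compiler, mirroring
# config.inductor_compile_config in the hunk above.
```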
@@ -427,9 +489,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         PiecewiseCompileInterpreter(
             self.split_gm,
             submod_names_to_compile,
-            self.inductor_config,
             self.graph_pool,
-            self.compile_config,
+            self.sglang_config,
             self,
         ).run(*example_inputs)
 
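For orientation, this is roughly how the backend is wired up once constructed: the instance itself is the `backend=` callable handed to `torch.compile`, and after this change it is built from the full `SGLangConfig` rather than a bare `CompilationConfig`. A hypothetical wiring sketch, not code from this diff:

```python
# Hypothetical usage; `sglang_config`, `graph_pool`, `model`, and `example_input`
# are assumed to exist and are not defined in this commit.
backend = SGLangBackend(sglang_config, graph_pool)
compiled_model = torch.compile(model, backend=backend, fullgraph=True)
out = compiled_model(example_input)   # first call triggers split + piecewise compile
```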