from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
@@ -75,7 +76,10 @@ def __init__(
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
-        self.is_fp8_kvcache = model_runner.kv_cache_dtype.startswith("fp8")
+        self.is_fp8_kvcache = model_runner.server_args.kv_cache_dtype in {
+            "fp8_e4m3",
+            "fp8_e5m2",
+        }

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

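Note on the change above: the removed line called .startswith("fp8") on model_runner.kv_cache_dtype, which by the time this backend sees it is assumed to be a resolved torch.dtype rather than a string, while server_args.kv_cache_dtype still carries the CLI string. A minimal sketch of that assumption, with a hypothetical helper name, not part of the patch:

    import torch

    def is_fp8_kv_cache(kv_cache_dtype_arg: str) -> bool:
        # server_args.kv_cache_dtype is assumed to carry the CLI string
        # ("auto", "fp8_e4m3", "fp8_e5m2"), not a torch.dtype.
        return kv_cache_dtype_arg in {"fp8_e4m3", "fp8_e5m2"}

    assert is_fp8_kv_cache("fp8_e4m3")
    assert not is_fp8_kv_cache("auto")
    # A resolved torch.dtype such as torch.float8_e4m3fn has no .startswith(),
    # which is why the string check has to go through server_args.
    assert not hasattr(torch.float8_e4m3fn, "startswith")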
@@ -105,7 +109,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads,
                1,
-                self.is_fp8_kvcache
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )
            self.forward_metadata = FlashMLADecodeMetadata(
                mla_metadata,
@@ -136,6 +140,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )

            # Use FlashMLADecodeMetadata which has the attributes forward_extend expects
@@ -170,6 +175,7 @@ def init_cuda_graph_state(
                ),
                self.num_draft_tokens * self.num_q_heads,
                1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )
        else:
            self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
@@ -178,6 +184,7 @@ def init_cuda_graph_state(
                ),
                self.num_q_heads,
                1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

@@ -208,6 +215,7 @@ def init_forward_metadata_capture_cuda_graph(
                seq_lens.to(torch.int32),
                num_q_heads,
                1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -233,6 +241,7 @@ def init_forward_metadata_capture_cuda_graph(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -283,6 +292,7 @@ def init_forward_metadata_replay_cuda_graph(
                seq_lens.to(torch.int32),
                num_q_heads,
                1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -308,6 +318,7 @@ def init_forward_metadata_replay_cuda_graph(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -356,7 +367,29 @@ def forward_decode(

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        if self.data_type == torch.float8_e4m3fn:
-            reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+            # For an FP8 KV cache, Q also needs to be converted to FP8 for the FlashMLA kernel.
+            # Reference: https://github.com/vllm-project/vllm/pull/22668
+            # In SGLang, layer.k_scale is used for both the q and k scales (similar to vLLM, where _q_scale defaults to k_scale).
+            if layer.k_scale is not None:
+                q_scale = layer.k_scale
+                descale_q = layer.k_scale.reshape(1)
+                descale_k = layer.k_scale.reshape(1)
+            else:
+                # Fall back to 1.0 if k_scale is not initialized.
+                q_scale = torch.ones((1,), dtype=torch.float32, device=reshape_q.device)
+                descale_q = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+                descale_k = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+
+            # Quantize Q with scaled_fp8_quant (matching vLLM's approach).
+            # Reshape to 2D first, since scaled_fp8_quant expects 2D input.
+            q_shape = reshape_q.shape
+            reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])
+            reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)
+            reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)
            o, _ = flash_mla_with_kvcache(
                q=reshape_q_fp8,
                k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
@@ -367,8 +400,8 @@ def forward_decode(
                num_splits=self.forward_metadata.num_splits,
                softmax_scale=layer.scaling,
                causal=True,
-                descale_q=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
-                descale_k=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
+                descale_q=descale_q,
+                descale_k=descale_k,
            )

            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
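The scale/descale pairing in the hunk above follows per-tensor FP8 quantization: Q is divided by q_scale before the FP8 cast, and the same value is passed to the kernel as descale_q so it can undo the division. A minimal sketch of that round trip, assuming per-tensor static scaling (a stand-in for scaled_fp8_quant and the kernel-side descale, not SGLang's actual kernels):

    import torch

    def fp8_round_trip(q: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        # Quantize: divide by the scale, clamp to the representable range, cast to FP8
        # (this mirrors what scaled_fp8_quant is assumed to do with a static scale).
        q_fp8 = (q / q_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        # Dequantize: what the kernel's descale_q argument is assumed to apply internally.
        return q_fp8.to(torch.float32) * q_scale

    q = torch.randn(4, 576)
    scale = torch.tensor([0.05])
    assert torch.allclose(fp8_round_trip(q, scale), q, atol=0.1, rtol=0.1)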
@@ -414,8 +447,31 @@ def forward_extend(
            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

            reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-            if self.data_type == torch.float8_e4m3fn:
-                reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+            if self.is_fp8_kvcache:
+                # For an FP8 KV cache, Q also needs to be converted to FP8 for the FlashMLA kernel.
+                # In SGLang, layer.k_scale is used for both the q and k scales (similar to vLLM, where _q_scale defaults to k_scale).
+                if layer.k_scale is not None:
+                    q_scale = layer.k_scale
+                    descale_q = layer.k_scale.reshape(1)
+                    descale_k = layer.k_scale.reshape(1)
+                else:
+                    # Fall back to 1.0 if k_scale is not initialized.
+                    q_scale = torch.ones(
+                        (1,), dtype=torch.float32, device=reshape_q.device
+                    )
+                    descale_q = torch.ones(
+                        (1,), dtype=torch.float32, device=reshape_q.device
+                    )
+                    descale_k = torch.ones(
+                        (1,), dtype=torch.float32, device=reshape_q.device
+                    )
+
+                # Quantize Q with scaled_fp8_quant (matching vLLM's approach).
+                # Reshape to 2D first, since scaled_fp8_quant expects 2D input.
+                q_shape = reshape_q.shape
+                reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])
+                reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)
+                reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)
                o, _ = flash_mla_with_kvcache(
                    q=reshape_q_fp8,
                    k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
@@ -427,12 +483,8 @@ def forward_extend(
                    num_splits=self.forward_metadata.num_splits,
                    softmax_scale=layer.scaling,
                    causal=True,
-                    descale_q=torch.ones(
-                        (1), dtype=torch.float32, device=reshape_q.device
-                    ),
-                    descale_k=torch.ones(
-                        (1), dtype=torch.float32, device=reshape_q.device
-                    ),
+                    descale_q=descale_q,
+                    descale_k=descale_k,
                )
            else:
                o, _ = flash_mla_with_kvcache(