Commit afcb9d5

fix fp8 accuracy (#13832)

1 parent 8400e3f commit afcb9d5
File tree

1 file changed (+65 lines, -13 lines)


python/sglang/srt/layers/attention/flashmla_backend.py

Lines changed: 65 additions & 13 deletions
@@ -14,6 +14,7 @@
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -75,7 +76,10 @@ def __init__(
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
-        self.is_fp8_kvcache = model_runner.kv_cache_dtype.startswith("fp8")
+        self.is_fp8_kvcache = model_runner.server_args.kv_cache_dtype in {
+            "fp8_e4m3",
+            "fp8_e5m2",
+        }
 
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
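The new check keys off the user-facing kv-cache-dtype string on server_args rather than the resolved cache dtype. A minimal standalone sketch of that membership test, assuming the server-args values are the strings "auto", "fp8_e4m3", and "fp8_e5m2" (the helper name below is illustrative, not part of the commit):

# Illustrative sketch of the FP8 KV-cache detection introduced above.
# Assumes kv_cache_dtype is the server-args string, not a torch.dtype.
FP8_KV_CACHE_DTYPES = {"fp8_e4m3", "fp8_e5m2"}

def is_fp8_kvcache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype in FP8_KV_CACHE_DTYPES

assert is_fp8_kvcache("fp8_e4m3") and is_fp8_kvcache("fp8_e5m2")
assert not is_fp8_kvcache("auto")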

@@ -105,7 +109,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 forward_batch.seq_lens.to(torch.int32),
                 self.num_q_heads,
                 1,
-                self.is_fp8_kvcache
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.forward_metadata = FlashMLADecodeMetadata(
                 mla_metadata,
@@ -136,6 +140,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
 
             # Use FlashMLADecodeMetadata which has the attributes forward_extend expects
@@ -170,6 +175,7 @@ def init_cuda_graph_state(
                 ),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
         else:
             self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
@@ -178,6 +184,7 @@ def init_cuda_graph_state(
                 ),
                 self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
 
@@ -208,6 +215,7 @@ def init_forward_metadata_capture_cuda_graph(
                 seq_lens.to(torch.int32),
                 num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -233,6 +241,7 @@ def init_forward_metadata_capture_cuda_graph(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -283,6 +292,7 @@ def init_forward_metadata_replay_cuda_graph(
                 seq_lens.to(torch.int32),
                 num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -308,6 +318,7 @@ def init_forward_metadata_replay_cuda_graph(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
                 1,
+                is_fp8_kvcache=self.is_fp8_kvcache,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
             self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
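Every get_mla_metadata call site (regular decode, speculative draft, and CUDA-graph capture/replay) now forwards the flag as a keyword argument, so FP8-aware scheduling metadata is produced consistently. A hedged sketch of that call pattern, assuming the FlashMLA Python API is exposed as flash_mla.get_mla_metadata and that the installed build accepts the is_fp8_kvcache keyword this commit relies on:

# Sketch of the metadata call pattern used above (assumed API; needs a CUDA device
# and a FlashMLA build that supports the is_fp8_kvcache keyword).
import torch
from flash_mla import get_mla_metadata

cache_seqlens = torch.tensor([128, 256], dtype=torch.int32, device="cuda")
num_q_heads = 16       # illustrative number of query heads per KV head
is_fp8_kvcache = True  # mirrors self.is_fp8_kvcache computed in __init__

mla_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    num_q_heads,
    1,                              # MLA uses a single compressed KV head
    is_fp8_kvcache=is_fp8_kvcache,  # keyword form keeps the positional args unambiguous
)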
@@ -356,7 +367,29 @@ def forward_decode(
 
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         if self.data_type == torch.float8_e4m3fn:
-            reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+            # For FP8 KV cache, Q needs to be converted to FP8 for FlashMLA kernel
+            # Reference: https://github.com/vllm-project/vllm/pull/22668
+            # In SGLang, we use layer.k_scale for both q and k scales (similar to vLLM where _q_scale defaults to k_scale)
+            if layer.k_scale is not None:
+                q_scale = layer.k_scale
+                descale_q = layer.k_scale.reshape(1)
+                descale_k = layer.k_scale.reshape(1)
+            else:
+                # Fallback to 1.0 if k_scale is not initialized
+                q_scale = torch.ones((1,), dtype=torch.float32, device=reshape_q.device)
+                descale_q = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+                descale_k = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+
+            # Quantize Q using scaled_fp8_quant (matching vLLM's approach)
+            # Reshape to 2D for scaled_fp8_quant (which requires 2D input)
+            q_shape = reshape_q.shape
+            reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])
+            reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)
+            reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)
             o, _ = flash_mla_with_kvcache(
                 q=reshape_q_fp8,
                 k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
@@ -367,8 +400,8 @@ def forward_decode(
                 num_splits=self.forward_metadata.num_splits,
                 softmax_scale=layer.scaling,
                 causal=True,
-                descale_q=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
-                descale_k=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
+                descale_q=descale_q,
+                descale_k=descale_k,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
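For context on the accuracy fix: with a static per-tensor scale, quantizing Q amounts to dividing by the scale, clamping to the FP8-representable range, and casting, while the kernel multiplies back by the value passed as descale_q/descale_k. A minimal pure-PyTorch sketch of that contract as it is used above (a reference illustration of the assumed scaled_fp8_quant semantics, not the fused kernel):

# Reference sketch of per-tensor static FP8 (e4m3) quantization with a known scale.
import torch

def quantize_fp8_per_tensor(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Scale down, clamp to the representable e4m3 range, then cast to FP8.
    return (x.float() / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

x = torch.randn(4, 576)                # e.g. Q rows with MLA head_dim 512 + 64
scale = torch.tensor([0.05], dtype=torch.float32)
x_fp8 = quantize_fp8_per_tensor(x, scale)

# The attention kernel dequantizes with the same value supplied as descale_q:
x_restored = x_fp8.float() * scale
print((x - x_restored).abs().max())    # small, bounded by FP8 rounding error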
@@ -414,8 +447,31 @@ def forward_extend(
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-        if self.data_type == torch.float8_e4m3fn:
-            reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
+        if self.is_fp8_kvcache:
+            # For FP8 KV cache, Q needs to be converted to FP8 for FlashMLA kernel
+            # In SGLang, we use layer.k_scale for both q and k scales (similar to vLLM where _q_scale defaults to k_scale)
+            if layer.k_scale is not None:
+                q_scale = layer.k_scale
+                descale_q = layer.k_scale.reshape(1)
+                descale_k = layer.k_scale.reshape(1)
+            else:
+                # Fallback to 1.0 if k_scale is not initialized
+                q_scale = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+                descale_q = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+                descale_k = torch.ones(
+                    (1,), dtype=torch.float32, device=reshape_q.device
+                )
+
+            # Quantize Q using scaled_fp8_quant (matching vLLM's approach)
+            # Reshape to 2D for scaled_fp8_quant (which requires 2D input)
+            q_shape = reshape_q.shape
+            reshape_q_2d = reshape_q.reshape(-1, q_shape[-1])
+            reshape_q_fp8_2d, _ = scaled_fp8_quant(reshape_q_2d, q_scale)
+            reshape_q_fp8 = reshape_q_fp8_2d.reshape(q_shape)
             o, _ = flash_mla_with_kvcache(
                 q=reshape_q_fp8,
                 k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
@@ -427,12 +483,8 @@ def forward_extend(
                 num_splits=self.forward_metadata.num_splits,
                 softmax_scale=layer.scaling,
                 causal=True,
-                descale_q=torch.ones(
-                    (1), dtype=torch.float32, device=reshape_q.device
-                ),
-                descale_k=torch.ones(
-                    (1), dtype=torch.float32, device=reshape_q.device
-                ),
+                descale_q=descale_q,
+                descale_k=descale_k,
             )
         else:
             o, _ = flash_mla_with_kvcache(
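The decode and extend paths now carry identical Q-quantization logic; a natural follow-up would be to factor it into a shared helper. A hypothetical sketch of such a helper, built only from the code added in this commit (the function name and placement are illustrative, not part of the change):

# Hypothetical consolidation of the duplicated Q-quantization logic above.
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant


def quantize_q_for_fp8_kvcache(reshape_q: torch.Tensor, k_scale):
    """Return (q_fp8, descale_q, descale_k) for the FlashMLA FP8 KV-cache path."""
    if k_scale is not None:
        q_scale = k_scale
        descale_q = k_scale.reshape(1)
        descale_k = k_scale.reshape(1)
    else:
        # Fall back to a scale of 1.0 when k_scale is not initialized.
        ones = torch.ones((1,), dtype=torch.float32, device=reshape_q.device)
        q_scale, descale_q, descale_k = ones, ones, ones

    # scaled_fp8_quant expects a 2D input, so flatten all leading dimensions.
    q_shape = reshape_q.shape
    q_fp8_2d, _ = scaled_fp8_quant(reshape_q.reshape(-1, q_shape[-1]), q_scale)
    return q_fp8_2d.reshape(q_shape), descale_q, descale_k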
