Commit 10a9b66

fix: fix accuracy bug in flashinfer and fa3 kernel. (#995)
1 parent: 250318a

File tree

2 files changed: +6, -5 lines


lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 4 additions & 3 deletions
@@ -23,6 +23,7 @@
 from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
+from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
@@ -302,7 +303,7 @@ def _context_attention_flashattention_kernel_with_CC(
         self,
         q: torch.Tensor,
         kv,
-        infer_state: Deepseek2FlashInferStateInfo,
+        infer_state: Deepseek2FlashAttentionStateInfo,
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
@@ -323,7 +324,7 @@ def _context_attention_flashattention_kernel_with_CC(
             k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
             v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
             cu_seqlens_q=infer_state.cu_seqlens_q,
-            cu_seqlens_k=infer_state.cu_seqlens_k,
+            cu_seqlens_k=infer_state.cu_seqlens_q,
             max_seqlen_q=infer_state.q_max_seq_len,
             max_seqlen_k=infer_state.max_seq_len,
             softmax_scale=self.softmax_scale,
@@ -547,7 +548,7 @@ def _context_attention_kernel_origin_fp8(
         return o_tensor

     def _token_gqa_decode_attention_flashattention(
-        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
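
The substantive change in this file is the cu_seqlens_k argument: the varlen flash-attention call now receives the same prefix-sum offsets for K as for Q. That only makes sense if, in this CC path, the decompressed k/v tensors are packed token-for-token like q (which is what the k.view / v.view reshapes above suggest), so offsets taken from the KV-cache side would slice the per-request K/V windows at the wrong boundaries and corrupt the attention output. The two annotation changes from Deepseek2FlashInferStateInfo to Deepseek2FlashAttentionStateInfo are type hints only. Below is a minimal, self-contained sketch of the cu_seqlens convention assumed here; build_cu_seqlens and the example lengths are illustrative, not lightllm code.

import torch
import torch.nn.functional as F

def build_cu_seqlens(seq_lens: torch.Tensor) -> torch.Tensor:
    # Varlen attention kernels take exclusive prefix sums of per-request
    # lengths: [0, l0, l0 + l1, ...] with batch_size + 1 int32 entries.
    return F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))

# Hypothetical packed prefill batch with three requests.
q_seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens_q = build_cu_seqlens(q_seq_lens)   # tensor([0, 5, 8, 15])

# When k/v share q's packed layout, K must be sliced with the same offsets,
# which is what the fixed call does (cu_seqlens_k=infer_state.cu_seqlens_q).
cu_seqlens_k = cu_seqlens_q
assert torch.equal(cu_seqlens_k, cu_seqlens_q)

Note that max_seqlen_k only needs to be an upper bound on the true per-request KV length, which is presumably why it can stay as infer_state.max_seq_len.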

lightllm/models/llama/flashinfer_struct.py

Lines changed: 2 additions & 2 deletions
@@ -81,8 +81,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.req_manager.req_to_token_indexs,
                 self.b_req_idx,
                 self.b_seq_len,
-                kv_starts,
-                self.max_len_in_batch,
+                kv_starts[:-1],
+                self.max_kv_seq_len,
                 kv_indices,
             )
             self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
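
Here the call that fills kv_indices (presumably lightllm's Triton index-repacking helper) now receives kv_starts[:-1] and self.max_kv_seq_len. kv_starts is a batch_size + 1 prefix-sum array, so dropping its last entry leaves one start offset per request, which is the shape the helper appears to index by; the bound passed next to it must cover the longest total KV sequence in the batch, and swapping max_len_in_batch for max_kv_seq_len suggests the two can differ (for example under chunked prefill or prefix caching), where an undersized bound would leave part of the cached indices uncopied — presumably the flashinfer-side accuracy bug this commit names. The snippet below is a pure-PyTorch sketch of that presumed behaviour; repack_kv_index_ref, its argument order, and the example tensors are assumptions made for illustration, not the real Triton kernel.

import torch

def repack_kv_index_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_kv_start, max_kv_len, out):
    # For each request, copy its first b_seq_len[i] token slots from the
    # request's row of req_to_token_indexs into the flat `out` buffer at
    # offset b_kv_start[i].
    for i in range(b_req_idx.numel()):
        # A kernel that sizes its copy loop from max_kv_len silently truncates
        # requests whose cached KV is longer than the bound, which is why the
        # bound has to be the max KV length of the batch.
        n = min(int(b_seq_len[i]), max_kv_len)
        start = int(b_kv_start[i])
        out[start:start + n] = req_to_token_indexs[b_req_idx[i], :n]

# Two requests holding 6 and 3 tokens in the paged KV cache.
req_to_token_indexs = torch.arange(32).reshape(2, 16)
b_req_idx = torch.tensor([0, 1])
b_seq_len = torch.tensor([6, 3])
kv_starts = torch.tensor([0, 6, 9])      # batch_size + 1 prefix sums
kv_indices = torch.empty(9, dtype=torch.long)

# Per-request starts are kv_starts[:-1]; the bound is the max KV length (6).
repack_kv_index_ref(req_to_token_indexs, b_req_idx, b_seq_len,
                    kv_starts[:-1], max_kv_len=6, out=kv_indices)
# kv_indices -> tensor([ 0,  1,  2,  3,  4,  5, 16, 17, 18])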
