@@ -23,6 +23,7 @@
 from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
+from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
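
The new import brings in the infer-state class used by the flash-attention backend; the `infer_state` annotation changes in the hunks below follow from it. A minimal sketch of the assumed pairing (the selector function here is hypothetical, not part of lightllm; the three class names are the ones imported above):

```python
# Hypothetical sketch: each attention backend is assumed to pair with its own
# infer-state class, so moving these kernels to flash-attention also changes
# the type of the `infer_state` argument they accept.
def pick_infer_state_class(backend: str):
    if backend == "flashattention":
        return Deepseek2FlashAttentionStateInfo
    if backend == "flashinfer":
        return Deepseek2FlashInferStateInfo
    return Deepseek2InferStateInfo
```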
@@ -302,7 +303,7 @@ def _context_attention_flashattention_kernel_with_CC(
         self,
         q: torch.Tensor,
         kv,
-        infer_state: Deepseek2FlashInferStateInfo,
+        infer_state: Deepseek2FlashAttentionStateInfo,
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
@@ -323,7 +324,7 @@ def _context_attention_flashattention_kernel_with_CC(
             k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
             v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
             cu_seqlens_q=infer_state.cu_seqlens_q,
-            cu_seqlens_k=infer_state.cu_seqlens_k,
+            cu_seqlens_k=infer_state.cu_seqlens_q,
             max_seqlen_q=infer_state.q_max_seq_len,
             max_seqlen_k=infer_state.max_seq_len,
             softmax_scale=self.softmax_scale,
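
The `cu_seqlens_k` change reflects that this CC prefill path attends over exactly the query tokens: with the KV decompressed for the whole prompt and no cached prefix, per-request key counts equal query counts, so the same prefix-sum tensor serves both. A runnable illustration of the `cu_seqlens` convention used by varlen flash-attention kernels (the batch here is made up):

```python
import torch
import torch.nn.functional as F

# cu_seqlens_* are exclusive prefix sums of per-request token counts. With no
# prefix cache, every key token is also a query token, so the tensors coincide.
seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32)  # assumed batch of 3 prompts
cu_seqlens_q = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = cu_seqlens_q  # mirrors the change in this hunk
print(cu_seqlens_q)  # tensor([ 0,  5,  8, 16], dtype=torch.int32)
```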
@@ -547,7 +548,7 @@ def _context_attention_kernel_origin_fp8(
         return o_tensor
 
     def _token_gqa_decode_attention_flashattention(
-        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
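
For context on the decode path kept here: the slice splits off the rotary part of `q`, and the `bmm` absorbs the key up-projection into the query so attention can run directly against the compressed KV latent. A shape-only sketch with made-up MLA dimensions (a plain tensor stands in for lightllm's `k_b_proj_` weight object):

```python
import torch

# Made-up shapes: project q_nope per head into the kv_lora latent space,
# matching the transpose/bmm/transpose pattern in the code above.
tokens, heads, nope_dim, lora_dim = 4, 16, 128, 512
q_nope = torch.randn(tokens, heads, nope_dim)
k_b_proj = torch.randn(heads, nope_dim, lora_dim)  # assumed per-head weight
q_absorbed = torch.bmm(q_nope.transpose(0, 1), k_b_proj).transpose(0, 1)
print(q_absorbed.shape)  # torch.Size([4, 16, 512])
```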