
Commit ee7100c

need fix
1 parent b9f23de

File tree

4 files changed: +16 / -11 lines changed

lightllm/models/deepseek3_2/__init__.py

Whitespace-only changes.

lightllm/models/deepseek3_2/infer_struct.py

Lines changed: 8 additions & 2 deletions

@@ -9,8 +9,8 @@ def __init__(self):
         self.page_table_size_1 = None
         self.ks = None
         self.ke = None
-
-        self.topk_indices = None
+        self.nsa_cu_seqlens_k = None
+        self.index_topk = 2048
         return

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
@@ -24,3 +24,9 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         # since b_q_seq_len represents the new tokens being processed
         if self.b_ready_cache_len is None:
             self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len
+
+        self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk)
+        assert self.nsa_cache_seqlens.dtype == torch.int32
+        self.nsa_cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
+        )
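The two new tensors follow the usual varlen-attention bookkeeping: per-request KV lengths capped at index_topk, then their cumulative sums padded with a leading zero to form cu_seqlens. A minimal sketch of what init_some_extra_state now computes, with made-up lengths (the values below are illustrative, not from the commit):

import torch

# Illustrative per-request attended-KV lengths; at runtime these come from
# infer_state.b_att_seq_len.
b_att_seq_len = torch.tensor([512, 3000, 1024], dtype=torch.int32)
index_topk = 2048  # cap on selected KV entries, matching the default set in __init__

nsa_cache_seqlens = b_att_seq_len.clamp(max=index_topk)  # [512, 2048, 1024]
nsa_cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)
print(nsa_cu_seqlens_k)  # tensor([   0,  512, 2560, 3584], dtype=torch.int32)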

lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 9 deletions

@@ -86,7 +86,7 @@ def _nsa_context_attention_kernel(
         mla_out, _, _ = flash_mla_sparse_fwd(
             q=q_all,
             kv=infer_state.mem_manager.kv_buffer[self.layer_num_],
-            indices=self.topk_indices,
+            indices=self.topk_indices.unsqueeze(1),
             sm_scale=self.softmax_scale,
             d_v=self.kv_lora_rank,
         )
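The only change in the context path is the extra singleton dimension on the top-k indices. The exact layout flash_mla_sparse_fwd expects is not shown in this diff, so reading the middle axis as a head-group dimension is an assumption; the shape effect of the call itself is straightforward:

import torch

# Hypothetical shapes: one row of index_topk selected KV positions per token.
num_tokens, index_topk = 4, 2048
topk_indices = torch.zeros((num_tokens, index_topk), dtype=torch.int32)

# unsqueeze(1): [num_tokens, index_topk] -> [num_tokens, 1, index_topk]
print(topk_indices.unsqueeze(1).shape)  # torch.Size([4, 1, 2048])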
@@ -100,23 +100,17 @@ def _nsa_token_attention_kernel(
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
         k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim)
         kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank)
-        k_descale, v_descale = None, None
         o_tensor = flash_attn_with_kvcache(
             q=q_rope,
             k_cache=k_rope,
             v_cache=kv_nope,
             qv=q_nope,
             page_table=self.topk_indices,
-            cache_seqlens=infer_state.b_att_seq_len,
+            cache_seqlens=infer_state.nsa_cache_seqlens,
             cu_seqlens_q=infer_state.cu_seqlens_q,
-            cu_seqlens_k_new=infer_state.cu_seqlens_k,
+            cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k,
             max_seqlen_q=infer_state.max_q_seq_len,
             softmax_scale=self.softmax_scale,
             causal=True,
-            window_size=(-1, -1),
-            softcap=0.0,
-            k_descale=k_descale,
-            v_descale=v_descale,
-            return_softmax_lse=False,
         )
         return o_tensor
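The cache_seqlens / cu_seqlens_k_new swap keeps the arguments consistent with the sparse page_table: topk_indices holds at most index_topk entries per row, so the per-request KV length handed to the kernel must be the clamped nsa_cache_seqlens (and its padded cumsum), not the full b_att_seq_len. The dropped arguments (window_size, softcap, k_descale, v_descale, return_softmax_lse) appear to have been at their no-op values anyway. A small consistency sketch under assumed shapes:

import torch

index_topk = 2048
topk_indices = torch.zeros((3, index_topk), dtype=torch.int32)  # stand-in page_table
b_att_seq_len = torch.tensor([512, 3000, 1024], dtype=torch.int32)

nsa_cache_seqlens = b_att_seq_len.clamp(max=index_topk)
# Every per-request length now fits within the page_table's row width.
assert int(nsa_cache_seqlens.max()) <= topk_indices.shape[1]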

lightllm/models/deepseek3_2/model.py

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,11 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel):
     # infer state class
     infer_state_class = Deepseek3_2FlashAttentionStateInfo

+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        self.index_topk = self.config["index_topk"]
+        return
+
     def _init_mem_manager(self):
         manager_class = Deepseek3_2MemoryManager
         if "triton_fp8kv" in self.mode:
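index_topk is now taken from the model's config ("index_topk" key) rather than assumed, and init_some_extra_state clamps with model.index_topk, so the config value is what takes effect at runtime (the infer state's own index_topk = 2048 remains only as a default). A rough sketch of that flow, using a stand-in config dict in place of the loaded model config:

import torch

config = {"index_topk": 2048}  # stand-in for the loaded model config

class _ToyModel:
    def __init__(self, config):
        self.config = config
        self.index_topk = self.config["index_topk"]  # mirrors Deepseek3_2TpPartModel.__init__

model = _ToyModel(config)
b_att_seq_len = torch.tensor([4096], dtype=torch.int32)
print(b_att_seq_len.clamp(max=model.index_topk))  # tensor([2048], dtype=torch.int32)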
