support Deepseek3.2 #1103
base: main
File 1 (modified): register the new model type with the existing DeepSeek-V2/V3 implementation.

```diff
@@ -50,7 +50,7 @@ def __init__(self, model):
     self.softmax_scale = self.softmax_scale * mscale * mscale


-@ModelRegistry(["deepseek_v2", "deepseek_v3"])
+@ModelRegistry(["deepseek_v2", "deepseek_v3", "deepseek_v32"])
```
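This is the only change to existing model code: configs whose `model_type` is `deepseek_v32` are routed to the same DeepSeek-V2/V3 implementation. As a rough illustration of what a registry decorator of this shape does (a sketch only, not lightllm's actual `ModelRegistry` implementation):

```python
# Hypothetical sketch of a model-type registry; not lightllm's real code.
_MODEL_CLASSES = {}

def ModelRegistry(model_types):
    """Map each model_type string in `model_types` to the decorated class."""
    def _register(cls):
        for name in model_types:
            _MODEL_CLASSES[name] = cls
        return cls
    return _register

# Usage mirrors the diff above: decorate the DeepSeek-V2/V3 model class with
# @ModelRegistry(["deepseek_v2", "deepseek_v3", "deepseek_v32"]).
```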
File 2 (added, 69 lines): Deepseek3_2FlashAttentionStateInfo

```python
import torch
from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager


class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo):
    def __init__(self):
        super().__init__()
        self.lengths = None
        self.page_table_size_1 = None
        self.ks = None
        self.ke = None
        self.nsa_cu_seqlens_k = None
        self.index_topk = 2048
        return

    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        super().init_some_extra_state(model, input_ids)
        assert isinstance(self.mem_manager, Deepseek3_2MemoryManager)
        self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager

        # Ensure b_ready_cache_len is set for both prefill and decode modes
        if self.is_prefill:
            # b_ready_cache_len is already set in basemodel.py for prefill
            pass
        else:
            # In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len
            # since b_q_seq_len represents the new tokens being processed
            if self.b_ready_cache_len is None:
                self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len

        self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk)
        assert self.nsa_cache_seqlens.dtype == torch.int32
        self.nsa_cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
        )

        # Pre-compute NSA indexer indexing structures
        self._init_nsa_indexing_structures()

    def _init_nsa_indexing_structures(self):
        """Pre-compute ks, ke, lengths, and page_table_size_1 for the NSA indexer."""
        mem_index_list = []
        ks_list = []
        ke_list = []
        lengths_list = []
        offset = 0
        num_seq_len = self.b_req_idx.shape[0]
        self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda')

        for i in range(num_seq_len):
            seq_len = self.b_seq_len[i]
            q_seq_len = self.b_q_seq_len[i]
            mem_index = self.req_manager.req_to_token_indexs[i, :seq_len]
            mem_index_list.append(mem_index)
            self.page_table_size_1[i, :seq_len] = mem_index
            ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset
            ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1
            ks_list.append(ks)
            ke_list.append(ke)
            lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda'))
            offset += seq_len

        self.mem_index = torch.cat(mem_index_list, dim=0)
        # ks : [seq_len_q] marks the start position of each query's kv range
        # ke : [seq_len_q] marks the end position of each query's kv range
        self.ks = torch.cat(ks_list, dim=0)
        self.ke = torch.cat(ke_list, dim=0)
        self.lengths = torch.cat(lengths_list, dim=0)
```
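To make these indexing structures concrete, here is a small standalone sketch (not part of the diff) that mirrors the loop above on a toy batch: a chunked prefill of 3 new tokens over a 5-token request, followed by a single-token step over a 4-token request. `ks`/`ke` are flat offsets into the batch's concatenated kv tokens, and `lengths` counts the kv tokens visible to each query row.

```python
import torch

# Toy batch: request 0 has seq_len 5 with 3 new query tokens,
# request 1 has seq_len 4 with 1 new query token.
b_seq_len = torch.tensor([5, 4])
b_q_seq_len = torch.tensor([3, 1])

ks_list, ke_list, lengths_list = [], [], []
offset = 0
for seq_len, q_seq_len in zip(b_seq_len.tolist(), b_q_seq_len.tolist()):
    ks_list.append(torch.full((q_seq_len,), offset, dtype=torch.int))
    ke_list.append(torch.arange(q_seq_len, dtype=torch.int) + offset + 1)
    lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int))
    offset += seq_len

print(torch.cat(ks_list))       # tensor([0, 0, 0, 5], dtype=torch.int32)
print(torch.cat(ke_list))       # tensor([1, 2, 3, 6], dtype=torch.int32)
print(torch.cat(lengths_list))  # tensor([3, 4, 5, 4], dtype=torch.int32)
```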
File 3 (added, 136 lines): NSAIndexerInfer

```python
from sgl_kernel import fast_topk_transform_fused
import deep_gemm
import torch
import torch.nn.functional as F

from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer
from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight
from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager
from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks
from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks
from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class NSAIndexerInfer(BaseLayerInfer):
    def __init__(self, layer_idx, network_config, mode=[]):
        super().__init__()
        self.layer_idx_ = layer_idx
        self.network_config_ = network_config
        self.mode = mode
        self.index_topk = network_config["index_topk"]
        self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_
        self.tp_k_head_num_ = 1
        self.tp_v_head_num_ = 1
        self.qk_nope_head_dim = network_config["qk_nope_head_dim"]
        self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
        self.index_head_dim = network_config["index_head_dim"]
        self.eps = network_config["rms_norm_eps"]
        self.block_size = network_config["quantization_config"]["weight_block_size"][0]
        self.scale_fmt = network_config["quantization_config"]["scale_fmt"]
        self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
        self.index_n_heads = network_config["index_n_heads"]
        self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale
        return

    def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
                           cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False):
        seq_len_kv = kv.shape[0]

        if cost_only:
            start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv)
            end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv)
            count_ones_per_row = (end - start).clamp(min=0)
            return count_ones_per_row.sum()

        k = kv
        q = q.float()
        k = k.float()

        mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
        mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
        mask = mask_lo & mask_hi

        score = torch.einsum('mhd,nd->hmn', q, k)
        logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
        logits = logits.masked_fill(~mask, float('-inf'))

        cost = mask.sum()
        return logits, cost
```
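The reference path above is dense PyTorch: it computes per-head ReLU'd q·k scores, combines the heads with per-query weights, and masks each query row to its `[ks, ke)` kv window. A tiny self-contained sketch of the same math on toy tensors (shapes and values are illustrative only):

```python
import torch

m, n, h, d = 3, 5, 2, 4       # 3 query tokens, 5 kv tokens, 2 indexer heads, dim 4
q = torch.randn(m, h, d)
kv = torch.randn(n, d)
weights = torch.rand(m, h)    # per-query, per-head mixing weights
ks = torch.tensor([0, 0, 0])  # query i may see kv positions [ks[i], ke[i])
ke = torch.tensor([3, 4, 5])

score = torch.einsum("mhd,nd->hmn", q.float(), kv.float())      # [h, m, n]
logits = (score.relu() * weights.t().unsqueeze(-1)).sum(dim=0)  # [m, n]

pos = torch.arange(n)[None, :]
mask = (pos >= ks[:, None]) & (pos < ke[:, None])
logits = logits.masked_fill(~mask, float("-inf"))
print(logits.shape)  # torch.Size([3, 5]); positions outside each window are -inf
```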
```python
    def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
                    infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor:

        q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight)
        q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt)
        k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt)

        destindex_copy_indexer_ks(
            k_fp8.unsqueeze(1),
            k_scale.unsqueeze(1),
            infer_state.mem_index,
            infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
        )

        weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
        weights = weights.unsqueeze(-1) * q_scale

        mem_index = infer_state.mem_index
        ks = infer_state.ks
        ke = infer_state.ke
        lengths = infer_state.lengths
        page_table_1 = infer_state.page_table_size_1

        k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index)

        logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke)

        # Returns: [seq_q_len, topk]; invalid positions are padded with -1
        return fast_topk_transform_fused(
            score=logits,  # [seq_len_q, seq_len_kv]
            lengths=lengths,  # [seq_len_q]
            page_table_size_1=page_table_1,  # [seq_len_q, max(lengths)], invalid entries padded with 0
            cu_seqlens_q=infer_state.cu_seqlens_q,  # [seq_len_q + 1]
            topk=self.index_topk,
        )
```
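`fast_topk_transform_fused` comes from `sgl_kernel`; per the shape comments above it returns, for each query row, the memory-slot indices of its top-`index_topk` kv tokens, padded with -1. A rough, unfused sketch of that contract (illustrative only; it assumes the page table is indexed per request, as built in `Deepseek3_2FlashAttentionStateInfo`, and the real kernel's exact semantics may differ):

```python
import torch

def naive_topk_indices(score, ks, lengths, page_table_size_1, cu_seqlens_q, topk):
    """Unfused sketch: rank each query's visible kv positions by score, keep the
    top-k, translate them into memory slots via the request's page table, and
    pad the rest of the [seq_len_q, topk] result with -1."""
    seq_len_q = score.shape[0]
    out = torch.full((seq_len_q, topk), -1, dtype=torch.long, device=score.device)
    for req in range(cu_seqlens_q.shape[0] - 1):
        for row in range(int(cu_seqlens_q[req]), int(cu_seqlens_q[req + 1])):
            n = int(lengths[row])    # kv tokens visible to this query row
            start = int(ks[row])     # where this request's kv block begins in score
            k = min(topk, n)
            top = torch.topk(score[row, start:start + n], k).indices
            out[row, :k] = page_table_size_1[req, top]
    return out
```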
```python
    def get_k_float32_from_buffer(self, buffer: torch.Tensor):
        k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn)
        k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1]
        k_float32 = k_fp8.float() * k_scale
        return k_float32

    @staticmethod
    def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
        assert x.dtype == torch.bfloat16
        from sgl_kernel import hadamard_transform

        hidden_size = x.size(-1)
        assert (
            hidden_size & (hidden_size - 1)
        ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
        return hadamard_transform(x, scale=hidden_size**-0.5)
```
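`_rotate_activation` applies a normalized Hadamard transform over the last dimension via `sgl_kernel.hadamard_transform`, which is why that dimension must be a power of two. For intuition, a pure-PyTorch reference sketch of a normalized fast Walsh-Hadamard transform (assumed to match the fused kernel up to numerical precision; the kernel itself is what the code uses):

```python
import torch

def hadamard_transform_ref(x: torch.Tensor) -> torch.Tensor:
    """Normalized fast Walsh-Hadamard transform over the last dim (power of 2)."""
    n = x.size(-1)
    assert n & (n - 1) == 0, "last dim must be a power of 2"
    batch_shape = x.shape[:-1]
    y = x.reshape(-1, n).float()
    h = 1
    while h < n:
        # Butterfly step: combine pairs (j, j + h) within blocks of size 2h.
        y = y.view(-1, n // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    return (y * n ** -0.5).reshape(*batch_shape, n).to(x.dtype)

# e.g. hadamard_transform_ref(torch.randn(4, 128, dtype=torch.bfloat16))
```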
```python
    def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
                      infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight):
        q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim)
        k = layer_weight.wk_proj_.mm(hidden_states)

        k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps)

        rotary_emb_fwd(
            q[:, :, : self.qk_rope_head_dim],
            k[:, None, : self.qk_rope_head_dim],
            infer_state.position_cos,
            infer_state.position_sin,
        )

        q = self._rotate_activation(q)
        k = self._rotate_activation(k)
        return q, k
```
Review comment: The call to `_autotune_warmup()` has been commented out. If this was for debugging, it should be removed. If autotuning is intentionally disabled for this model, it would be better to control this with a configuration flag, for clarity and to avoid accidental performance degradation.
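One way to act on this suggestion would be to gate the warmup behind an explicit switch instead of commenting the call out. A minimal sketch, using a hypothetical environment variable (not an existing lightllm option):

```python
import logging
import os

logger = logging.getLogger(__name__)

# Hypothetical flag name; pick whatever fits lightllm's existing config style.
_DISABLE_NSA_AUTOTUNE = os.getenv("LIGHTLLM_DISABLE_NSA_AUTOTUNE", "0") == "1"

def maybe_autotune_warmup(warmup_fn):
    """Run the kernel autotune warmup unless it was explicitly disabled."""
    if _DISABLE_NSA_AUTOTUNE:
        logger.info("NSA autotune warmup disabled via LIGHTLLM_DISABLE_NSA_AUTOTUNE=1")
        return
    warmup_fn()
```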