From 003247d41fe6bf53d97365fc4a382728f2f223f5 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 5 Nov 2025 03:55:14 +0000 Subject: [PATCH 01/12] support deepseek v3.2 --- .../layer_weights/transformer_layer_weight.py | 11 +- lightllm/common/triton_utils/autotuner.py | 2 +- lightllm/models/deepseek3_2/infer_struct.py | 9 ++ .../layer_infer/nsa_indexer_layer_inder.py | 142 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 127 ++++++++++++++++ .../layer_weights/nsa_indexer_layer_weight.py | 49 ++++++ .../layer_weights/transformer_layer_weight.py | 16 ++ lightllm/models/deepseek3_2/mem_manager.py | 47 ++++++ lightllm/models/deepseek3_2/model.py | 38 +++++ .../deepseek3_2/triton_kernel/__init__.py | 0 .../deepseek3_2/triton_kernel/act_quant.py | 137 +++++++++++++++++ .../triton_kernel/token_group_quant.py | 103 +++++++++++++ 12 files changed, 679 insertions(+), 2 deletions(-) create mode 100644 lightllm/models/deepseek3_2/infer_struct.py create mode 100644 lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py create mode 100644 lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py create mode 100644 lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/deepseek3_2/mem_manager.py create mode 100644 lightllm/models/deepseek3_2/model.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/__init__.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/act_quant.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 48167a067..07b3bf69c 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -36,8 +36,17 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MultiMMWeightTpl): + if isinstance(attr, TransformerLayerWeight): + attr.load_hf_weights(weights) + elif isinstance(attr, MultiMMWeightTpl): with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + + def verify_load(self): + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, TransformerLayerWeight): + attr.verify_load() + super().verify_load() \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index a919f7b28..c62a2572f 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -62,7 +62,7 @@ def autotune( as needed before invocation. 
""" - def decorator(fn): + def decorator(fn: Callable) -> Callable: return Autotuner( fn=fn, kernel_name=kernel_name, diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py new file mode 100644 index 000000000..6e5e766b2 --- /dev/null +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -0,0 +1,9 @@ +import os +import torch +import numpy as np +import torch.distributed as dist +from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo + + +class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): + pass \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py new file mode 100644 index 000000000..a3891f0f3 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -0,0 +1,142 @@ +from sgl_kernel import fast_topk_transform_fused +import deep_gemm +import torch +import torch.nn.functional as F + +from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant + + +class NSAIndexerInfer(BaseLayerInfer): + def __init__(self, layer_idx, network_config, mode=[]): + super().__init__() + self.layer_idx_ = layer_idx + self.network_config_ = network_config + self.mode = mode + self.index_topk = network_config["index_topk"] + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = 1 + self.tp_v_head_num_ = 1 + self.qk_nope_head_dim = network_config["qk_nope_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_head_dim = network_config["index_head_dim"] + self.eps = network_config["rms_norm_eps"] + self.block_size = network_config["quantization_config"]["weight_block_size"][0] + self.scale_fmt = network_config["quantization_config"]["scale_fmt"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.index_n_heads = network_config["index_n_heads"] + self.index_n_heads_scale = self.index_n_heads ** -0.5 + + self.q_lora = None + self.hidden_states = None + return + + def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + + def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + assert self.hidden_states is 
not None + assert self.q_lora is not None + + q, k = self._get_q_k_bf16(infer_state, layer_weight) + q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + + weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + + logits = fp8_paged_mqa_logits_torch( + q_fp8, k_fp8, weights, + infer_state.lengths, + infer_state.page_table, + infer_state.max_model_len + ) + + return fast_topk_transform_fused( + score=logits, + lengths=infer_state.lengths, + page_table_size_1=infer_state.page_table, + cu_seqlens_q=infer_state.b1_cu_q_seq_len, + topk=self.index_topk + ) + + @staticmethod + def _rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from sgl_kernel import hadamard_transform + + hidden_size = x.size(-1) + assert ( + hidden_size & (hidden_size - 1) + ) == 0, "Hidden size must be a power of 2 for Hadamard transform." + return hadamard_transform(x, scale=hidden_size**-0.5) + + def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight): + q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) + self.q_lora = None + + k = layer_weight.wk_proj_.mm(self.hidden_states) + self.hidden_states = None + k = F.layer_norm( + k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps + ).type_as(k) + + rotary_emb_fwd( + q[:, :, : self.qk_rope_head_dim], + k[:, None, : self.qk_rope_head_dim], + infer_state.position_cos, + infer_state.position_sin, + ) + + q = self._rotate_activation(q) + k = self._rotate_activation(k) + return q, k + + +# TODO +def fp8_paged_mqa_logits_torch(q: torch.Tensor, kv_cache: torch.Tensor, + weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, + max_model_len: int): + batch_size, next_n, heads, dim = q.size() + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') + weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() + for block_rk in range((context_len + block_size - 1) // block_size): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) + s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf')) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + return logits \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..6db8c14e8 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -0,0 +1,127 @@ +from functools import partial +from typing import override + +import torch +from 
sgl_kernel.flash_mla import flash_mla_sparse_fwd +from sgl_kernel.flash_attn import flash_attn_with_kvcache + +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer +from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd + + +class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, network_config, mode) + + self.indexer = NSAIndexerInfer( + layer_idx=self.layer_num_, + network_config=self.network_config_, + mode=mode + ) + return + + @override + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionInferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + + if self.q_lora_rank is None: + q = layer_weight.q_weight_.mm(input) + cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + else: + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + + self.indexer.hidden_states = input + self.indexer.q_lora = q + + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + rmsnorm_forward( + cache_kv[:, :, : self.kv_lora_rank], + weight=layer_weight.kv_a_layernorm_.weight, + eps=self.eps_, + out=cache_kv[:, :, : self.kv_lora_rank], + ) + + rotary_emb_fwd( + q_rope, + cache_kv[:, :, self.kv_lora_rank :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + @override + def _bind_attention(self): + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) + self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) + pass + + def _context_attention_flashmla_kernel_with_indexer( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek3_2FlashInferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + q_all = torch.cat([q_nope, q_rope], dim=-1) + topk_indices = self.indexer.get_indices( + infer_state, + layer_weight.indexer_layer_weight, + ) + mla_out, _, _ = flash_mla_sparse_fwd( + q=q_all, + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], + indices=topk_indices.unsqueeze(1), + sm_scale=self.softmax_scale, + d_v=self.kv_lora_rank, + ) + return mla_out + + def 
_token_attention_flashmla_kernel_with_indexer( + self, q, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + topk_indices = self.indexer.get_indices( + infer_state, + layer_weight.indexer_layer_weight, + ) + o = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=topk_indices, + cache_seqlens=infer_state.b_att_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=infer_state.max_q_seq_len, + softmax_scale=self.softmax_scale, + causal=True, + softcap=0.0, + return_softmax_lse=False, + num_splits=0, # TODO enable_deterministic_inference + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py new file mode 100644 index 000000000..47e0bfdac --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -0,0 +1,49 @@ +from typing_extensions import override + +import torch + +from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, NormWeight + + +class NSAIndexerWeight(TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + @override + def _init_weight(self): + prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" + + self.wq_b_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wq_b.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wq_b", + tp_rank=0, + tp_world_size=1, + ) + self.wk_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wk.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wk", + tp_rank=0, + tp_world_size=1, + ) + self.k_norm_ = NormWeight( + f"{prefix}.k_norm.weight", + torch.float32, + bias_name=f"{prefix}.k_norm.bias" + ) + self.weights_proj_ = ROWMMWeight( + weight_name=f"{prefix}.weights_proj.weight", + data_type=self.data_type_, + quant_cfg=None, + layer_num=self.layer_num_, + name="weights_proj", + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..2a03e1d6a --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -0,0 +1,16 @@ +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight + + +class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + 
self.indexer_layer_weight = NSAIndexerWeight( + layer_num=layer_num, + data_type=data_type, + network_config=network_config, + mode=mode, + quant_cfg=quant_cfg + ) + return diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py new file mode 100644 index 000000000..0aa0a0bdb --- /dev/null +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -0,0 +1,47 @@ +from typing_extensions import override +import torch + +from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager + + +class Deepseek3_2MemoryManager(Deepseek2MemoryManager): + def __init__( + self, + size, + dtype, + head_num, + head_dim, + layer_num, + index_head_dim, + index_quant_block_size, + k_cache_dtype=torch.float8_e4m3fn, + k_scale_dtype=torch.float32, + always_copy=False, + mem_fraction=0.9 + ): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + assert index_head_dim % index_quant_block_size == 0, "index_head_dim must be divisible by index_quant_block_size" + self.index_head_dim = index_head_dim + self.index_quant_block_size = index_quant_block_size + self.k_cache_dtype = k_cache_dtype + self.k_scale_dtype = k_scale_dtype + return + + @override + def get_cell_size(self): + index_k_cache_cell_size = self.index_head_dim * self.layer_num * torch._utils._element_size(self.k_cache_dtype) + index_k_scale_cell_size = (self.index_head_dim // self.index_quant_block_size) * self.layer_num * torch._utils._element_size(self.k_scale_dtype) + return super().get_cell_size() + index_k_cache_cell_size + index_k_scale_cell_size + + @override + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + super()._init_buffers(size, dtype, head_num, head_dim, layer_num) + self._init_indexer_k_cache_buffers() + return + + def _init_indexer_k_cache_buffers(self): + self.indexer_k_cache_buffers = torch.empty( + (self.layer_num, self.size + 1, self.index_head_dim), dtype=self.k_cache_dtype, device="cuda") + self.indexer_k_scale_buffers = torch.empty( + (self.layer_num, self.size + 1, self.index_head_dim // self.index_quant_block_size), dtype=self.k_scale_dtype, device="cuda") + return diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py new file mode 100644 index 000000000..3a244c77f --- /dev/null +++ b/lightllm/models/deepseek3_2/model.py @@ -0,0 +1,38 @@ +from lightllm.models.registry import ModelRegistry +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashInferStateInfo + +@ModelRegistry(["deepseek_v32"]) +class Deepseek3_2TpPartModel(Deepseek2TpPartModel): + # weight class + transformer_weight_class = Deepseek3_2TransformerLayerWeight + + # infer class + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + + # infer state class + infer_state_class = Deepseek3_2FlashInferStateInfo + + def _init_mem_manager(self): + # mtp 模式下需要在mem manger上扩展draft model使用的layer + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "deepseekv3_eagle": + added_mtp_layer_num += 1 + elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": + added_mtp_layer_num += 
get_env_start_args().mtp_step + + self.mem_manager = Deepseek3_2MemoryManager( + self.max_total_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + index_head_dim = self.config["index_head_dim"], + index_quant_block_size = self.config["index_quant_block_size"], + mem_fraction=self.mem_fraction, + ) + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/__init__.py b/lightllm/models/deepseek3_2/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek3_2/triton_kernel/act_quant.py b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py new file mode 100644 index 000000000..a4ecd0f51 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py @@ -0,0 +1,137 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/python/sglang/srt/layers/attention/nsa/triton_kernel.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +# Triton implementation +@triton.jit +def _act_quant_kernel( + X_ptr, + Y_ptr, + S_ptr, + M, + N, + group_size: tl.constexpr, + round_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Triton kernel for activation quantization. + + Each block processes BLOCK_M rows and group_size columns. + """ + # Get block IDs + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # FP8 constants + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1.0 / fp8_max + + # Calculate row and column offsets + row_start = pid_m * BLOCK_M + col_start = pid_n * group_size + + # Create offset arrays + rows = row_start + tl.arange(0, BLOCK_M) + cols = col_start + tl.arange(0, BLOCK_N) + + # Mask for valid rows and columns + row_mask = rows < M + col_mask = cols < N + mask = row_mask[:, None] & col_mask[None, :] + + # Load input data + x_ptrs = X_ptr + rows[:, None] * N + cols[None, :] + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Compute absolute max along columns (group_size dimension) for each row + x_abs = tl.abs(x) + amax = tl.max(x_abs, axis=1) # Shape: (BLOCK_M,) + + # Clamp amax to avoid division by zero + amax = tl.maximum(amax, 1e-4) + + # Compute scale + if round_scale: + # Fast round scale using bit manipulation approximation + # This is a simplified version - the exact bit manipulation is harder in Triton + # Using log2 + ceil + pow2 as approximation + log_val = tl.log2(amax * fp8_max_inv) + log_ceil = tl.ceil(log_val) + scale = tl.exp2(log_ceil) + else: + scale = amax * fp8_max_inv + + # Quantize: y = clamp(x / scale, fp8_min, fp8_max) + scale_broadcast = scale[:, None] + y = x / scale_broadcast + y = tl.minimum(tl.maximum(y, fp8_min), fp8_max) + + # Store quantized output + y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :] + tl.store(y_ptrs, y, mask=mask) + + # Store scales + s_cols = pid_n + s_ptrs = S_ptr + rows * (N // group_size) + s_cols + s_mask = row_mask + tl.store(s_ptrs, scale, mask=s_mask) + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization with Triton. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. 
+ block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + + # Flatten all dims except last + N = x.size(-1) + x_flat = x.view(-1, N) + M = x_flat.size(0) + + # Allocate output tensors + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + y_flat = y.view(-1, N) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + s_flat = s.view(-1, N // block_size) + + # Launch kernel + BLOCK_M = 32 + BLOCK_N = block_size + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size)) + round_scale = scale_fmt is not None + + _act_quant_kernel[grid]( + x_flat, + y_flat, + s_flat, + M, + N, + group_size=block_size, + round_scale=round_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=0 if round_scale else 2, + ) + + return y, s diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py new file mode 100644 index 000000000..dbf5c5199 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -0,0 +1,103 @@ +import triton +import triton.language as tl +import torch +from typing import Tuple + +fp8_min = -448.0 +fp8_max = 448.0 +fp8_dtype = torch.float8_e4m3fn + +@triton.jit +def _per_token_group_quant_mla_deep_gemm_masked_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + masked_m_ptr, + group_size, + y_stride_b, + y_stride_t, + y_q_stride_b, + y_q_stride_t, + y_s_stride_b, + y_s_stride_g, + eps, + fp8_min, + fp8_max, + NUM_GROUP: tl.constexpr, + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor for deep_gemm grouped_gemm_masked. + This function converts the tensor values into float8 values. + y and y_q: (b, t, k) + y_s: (b, k//group_size, t) + """ + t_id = tl.program_id(0) + b_id = tl.program_id(1) + + y_ptr += b_id * y_stride_b + t_id * y_stride_t + y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t + y_s_ptr += b_id * y_s_stride_b + t_id + + if t_id == 0: + tl.store(masked_m_ptr + b_id, tl.num_programs(0)) + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + for gid in range(NUM_GROUP): + y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to( + tl.float32 + ) + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask) + tl.store(y_s_ptr + gid * y_s_stride_g, y_s) + + +def per_token_group_quant_mla_deep_gemm_masked_fp8( + x: torch.Tensor, + group_size: int = 128, + eps: float = 1e-12, + dtype: torch.dtype = fp8_dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function quantizes input values to float8 values with per-token-group-quantization + for deep_gemm grouped_gemm_masked and specialized for mla absorbed case. 
+ """ + assert x.dim() == 3, "`x` is not a 3d-tensor" + + b, m, k = x.shape + aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel + num_tiles_k = k // group_size + assert num_tiles_k * group_size == k, f"k % {group_size} must be zero" + + x_q = x.new_empty((b, aligned_m, k), dtype=dtype) + x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32) + masked_m = x.new_empty((b,), dtype=torch.int32) + + BLOCK_SIZE = triton.next_power_of_2(group_size) + grid = (m, b) + + _per_token_group_quant_mla_deep_gemm_masked_fp8[grid]( + x, + x_q, + x_s, + masked_m, + group_size, + x.stride(0), + x.stride(1), + x_q.stride(0), + x_q.stride(1), + x_s.stride(0), + x_s.stride(1), + eps, + -fp8_max, + fp8_max, + num_tiles_k, + BLOCK_SIZE, + ) + + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m \ No newline at end of file From 03aeae67aa0dbe76df300373742c5bd03df50fae Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 5 Nov 2025 06:22:57 +0000 Subject: [PATCH 02/12] fix --- lightllm/models/deepseek3_2/infer_struct.py | 7 ++- .../layer_infer/nsa_indexer_layer_inder.py | 9 ++- .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/models/deepseek3_2/mem_manager.py | 63 +++++++++++++------ lightllm/models/deepseek3_2/model.py | 2 - 5 files changed, 57 insertions(+), 26 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 6e5e766b2..20f8b7e8d 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -3,7 +3,12 @@ import numpy as np import torch.distributed as dist from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2IndexerPagedMemoryManager, Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): - pass \ No newline at end of file + + def __init__(self): + super().__init__() + assert isinstance(self.req_manager.mem_manager, Deepseek3_2MemoryManager) + self.indexer_paged_mem_manager : Deepseek3_2IndexerPagedMemoryManager = self.req_manager.mem_manager.indexer_paged_mem_manager diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index a3891f0f3..100df16f9 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -28,7 +28,7 @@ def __init__(self, layer_idx, network_config, mode=[]): self.scale_fmt = network_config["quantization_config"]["scale_fmt"] self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) self.index_n_heads = network_config["index_n_heads"] - self.index_n_heads_scale = self.index_n_heads ** -0.5 + self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale self.q_lora = None self.hidden_states = None @@ -67,8 +67,13 @@ def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, laye q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + # write + # infer_state.mem_manager. 
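The "# write" placeholder above marks where the quantized indexer keys are meant to be persisted into the per-layer cache; a later patch in this series fills it in with the destindex_copy_indexer_ks Triton kernel. A minimal dense PyTorch sketch of that write path, assuming the 132-byte-per-token layout used there (128 FP8 key bytes followed by one float32 group scale); the helper name is illustrative, not part of the patch:

import torch

def write_indexer_ks_reference(buffer_u8: torch.Tensor, k_fp8: torch.Tensor,
                               k_scale: torch.Tensor, mem_index: torch.Tensor) -> None:
    # buffer_u8: [num_tokens, 1, 132] uint8 slice for one layer.
    # Bytes 0..127 hold the fp8 key, bytes 128..131 hold the float32 group scale.
    buffer_u8[mem_index, 0, :128] = k_fp8.view(torch.uint8)                  # [seq_len, 128]
    buffer_u8[mem_index, 0, 128:] = k_scale.contiguous().view(torch.uint8)   # [seq_len, 4]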
+ + # read + weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale - weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + weights = weights.unsqueeze(-1) * q_scale logits = fp8_paged_mqa_logits_torch( q_fp8, k_fp8, weights, diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 6db8c14e8..9f503e9bd 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -76,7 +76,7 @@ def _context_attention_flashmla_kernel_with_indexer( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashInferStateInfo, + infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index 0aa0a0bdb..f2613aacc 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,9 +1,37 @@ from typing_extensions import override import torch +from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager +from lightllm.utils.log_utils import init_logger +logger = init_logger(__name__) +class Deepseek3_2IndexerPagedMemoryManager: + def __init__(self, page_size): + self.page_size = page_size + return + + def set_size(self, size): + self.physics_size = size + self.num_pages = size // self.page_size + return + + def _init_buffers(self): + self.k_cache_buffer = torch.empty( + (self.page_size, 128), dtype=torch.float8_e4m3fn, device="cuda") + self.k_scale_buffer = torch.empty( + (self.page_size, 1), dtype=torch.float64, device="cuda") + return + + def alloc_paged_index(self, last_index: int, need_size): + pass + + def get_cell_size(self): + # Use for deepseek v3.2 exp only, 128 for k_cache(128 torch.float8_e4m3fn), 4 for scale(1 torch.float64) + return 128 + 4 + + class Deepseek3_2MemoryManager(Deepseek2MemoryManager): def __init__( self, @@ -12,36 +40,31 @@ def __init__( head_num, head_dim, layer_num, - index_head_dim, - index_quant_block_size, - k_cache_dtype=torch.float8_e4m3fn, - k_scale_dtype=torch.float32, always_copy=False, - mem_fraction=0.9 + mem_fraction=0.9, + page_size=64 ): + self.page_size = page_size + self.indexer_paged_mem_manager = Deepseek3_2IndexerPagedMemoryManager(page_size) super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - assert index_head_dim % index_quant_block_size == 0, "index_head_dim must be divisible by index_quant_block_size" - self.index_head_dim = index_head_dim - self.index_quant_block_size = index_quant_block_size - self.k_cache_dtype = k_cache_dtype - self.k_scale_dtype = k_scale_dtype + self.indexer_paged_mem_manager.set_size(self.size) return @override def get_cell_size(self): - index_k_cache_cell_size = self.index_head_dim * self.layer_num * torch._utils._element_size(self.k_cache_dtype) - index_k_scale_cell_size = (self.index_head_dim // self.index_quant_block_size) * self.layer_num * torch._utils._element_size(self.k_scale_dtype) - return super().get_cell_size() + index_k_cache_cell_size + index_k_scale_cell_size + return super().get_cell_size() + self.indexer_paged_mem_manager.get_cell_size() @override def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): super()._init_buffers(size, dtype, head_num, 
head_dim, layer_num) - self._init_indexer_k_cache_buffers() + self.indexer_paged_mem_manager._init_buffers() return - def _init_indexer_k_cache_buffers(self): - self.indexer_k_cache_buffers = torch.empty( - (self.layer_num, self.size + 1, self.index_head_dim), dtype=self.k_cache_dtype, device="cuda") - self.indexer_k_scale_buffers = torch.empty( - (self.layer_num, self.size + 1, self.index_head_dim // self.index_quant_block_size), dtype=self.k_scale_dtype, device="cuda") - return + @override + def profile_size(self, mem_fraction): + super().profile_size(mem_fraction) + if self.size % self.page_size != 0: + size_paged = (self.size // self.page_size + 1) * self.page_size + logger.warning(f"size {self.size} is not divisible by page_size {self.page_size}, will use paged_size {size_paged}") + self.size = size_paged + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 3a244c77f..5b3fc1f13 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -31,8 +31,6 @@ def _init_mem_manager(self): head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, - index_head_dim = self.config["index_head_dim"], - index_quant_block_size = self.config["index_quant_block_size"], mem_fraction=self.mem_fraction, ) return \ No newline at end of file From f172db907401bf0ce7d39492df0346283f3c86f4 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 6 Nov 2025 10:40:46 +0000 Subject: [PATCH 03/12] fix --- lightllm/common/basemodel/basemodel.py | 2 +- .../common/deepseek2_fp8kv_mem_manager.py | 4 +- lightllm/common/deepseek2_mem_manager.py | 4 +- lightllm/common/mem_manager.py | 17 ++- lightllm/models/__init__.py | 1 + lightllm/models/deepseek3_2/infer_struct.py | 26 +++- .../layer_infer/nsa_indexer_layer_inder.py | 136 ++++++++++------- .../layer_infer/transformer_layer_infer.py | 15 +- lightllm/models/deepseek3_2/mem_manager.py | 72 ++------- lightllm/models/deepseek3_2/model.py | 13 +- .../destindex_copy_indexer_ks.py | 137 ++++++++++++++++++ .../triton_kernel/fp8_mqa_logits.py | 0 12 files changed, 286 insertions(+), 141 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 77ca299b2..5be45c6b1 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -110,7 +110,7 @@ def __init__(self, kvargs): self._init_some_value() self._init_custom() self._init_inferstate_cls() - self._autotune_warmup() + # self._autotune_warmup() self._init_padded_req() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() diff --git a/lightllm/common/deepseek2_fp8kv_mem_manager.py b/lightllm/common/deepseek2_fp8kv_mem_manager.py index 00699f4b1..ffa0f2274 100644 --- a/lightllm/common/deepseek2_fp8kv_mem_manager.py +++ b/lightllm/common/deepseek2_fp8kv_mem_manager.py @@ -3,6 +3,6 @@ class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): # scale被追加到kv_buffer末尾, 因此加2, dtype统一改成uint8 - super().__init__(size, torch.uint8, head_num, 
head_dim + 2, layer_num, always_copy, mem_fraction) + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 4f106bdcf..46fd87dcf 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -14,8 +14,8 @@ class Deepseek2MemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 57ae9838b..167ce2760 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -19,7 +19,7 @@ class MemoryManager: - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -41,15 +41,16 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.can_use_mem_size = self.size - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name + if not is_sub_mem_manager: + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 96329eabe..9c9bdb9df 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -20,6 +20,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, InternVLPhi3TpPartModel, diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 20f8b7e8d..bfdb53fd6 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,14 +1,24 @@ -import os import torch -import numpy as np -import torch.distributed as dist from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2IndexerPagedMemoryManager, Deepseek3_2MemoryManager 
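The is_sub_mem_manager flag threaded through the managers above lets the V3.2 memory manager (later in this patch) nest a plain Deepseek2MemoryManager for the indexer keys without registering a second shared token counter. Each token then carries an extra 132 bytes per layer (128 FP8 key bytes plus a float32 scale) on top of the roughly 1152-byte MLA latent-KV cell, about an 11% increase. A rough per-token accounting sketch; kv_lora_rank=512, qk_rope_head_dim=64 and bf16 KV storage are assumptions here, not values taken from this patch:

def per_token_layer_bytes(kv_lora_rank: int = 512, qk_rope_head_dim: int = 64,
                          kv_dtype_bytes: int = 2) -> tuple:
    # MLA latent-KV cell vs. the extra NSA-indexer key cell, per token and per layer.
    mla_cell = (kv_lora_rank + qk_rope_head_dim) * kv_dtype_bytes  # 576 * 2 = 1152 bytes
    indexer_cell = 128 * 1 + 1 * 4                                 # fp8 key + fp32 scale = 132 bytes
    return mla_cell, indexer_cell

# per_token_layer_bytes() -> (1152, 132)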
+class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): -class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): - def __init__(self): super().__init__() - assert isinstance(self.req_manager.mem_manager, Deepseek3_2MemoryManager) - self.indexer_paged_mem_manager : Deepseek3_2IndexerPagedMemoryManager = self.req_manager.mem_manager.indexer_paged_mem_manager + self.lengths = None + self.page_table_size_1 = None + self.ks = None + self.ke = None + return + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + super().init_some_extra_state(model, input_ids) + # Ensure b_ready_cache_len is set for both prefill and decode modes + if self.is_prefill: + # b_ready_cache_len is already set in basemodel.py for prefill + pass + else: + # In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len + # since b_q_seq_len represents the new tokens being processed + if self.b_ready_cache_len is None: + self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 100df16f9..1977c211e 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -5,10 +5,12 @@ from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant - +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager +from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +# from lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): @@ -30,8 +32,6 @@ def __init__(self, layer_idx, network_config, mode=[]): self.index_n_heads = network_config["index_n_heads"] self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale - self.q_lora = None - self.hidden_states = None return def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, @@ -59,7 +59,7 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T cost = mask.sum() return logits, cost - def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + def get_indices(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: assert self.hidden_states is not None assert self.q_lora is not None @@ -67,29 +67,78 @@ def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, laye q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) - # write - # infer_state.mem_manager. 
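The rest of this hunk gathers the cached FP8 keys per request and builds flat ks/ke windows that deep_gemm.fp8_mqa_logits uses to restrict each query token to its causal prefix. A toy sketch of that indexing scheme for concatenated prefill requests (no prefix cache, i.e. b_ready_cache_len = 0); the helper is illustrative, not part of the patch:

import torch

def build_ks_ke(b_q_seq_len):
    # Tokens of all requests are laid out back to back; query j of a request that
    # starts at `offset` may attend to keys in the half-open window [offset, offset + j + 1).
    ks_list, ke_list, offset = [], [], 0
    for q_len in b_q_seq_len:
        ks = torch.full((q_len,), offset, dtype=torch.int32)
        ke = ks + torch.arange(q_len, dtype=torch.int32) + 1
        ks_list.append(ks)
        ke_list.append(ke)
        offset += q_len
    return torch.cat(ks_list), torch.cat(ke_list)

# build_ks_ke([3, 2]) -> ks = [0, 0, 0, 3, 3], ke = [1, 2, 3, 4, 5]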
- - # read + self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager) weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale - weights = weights.unsqueeze(-1) * q_scale - - logits = fp8_paged_mqa_logits_torch( - q_fp8, k_fp8, weights, - infer_state.lengths, - infer_state.page_table, - infer_state.max_model_len + weights = weights.unsqueeze(-1) * q_scale + + ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + + k_fp8_list = [] + k_scale_list = [] + ks_list = [] + ke_list = [] + offset = 0 + for i in range(infer_state.batch_size): + q_len = infer_state.b_q_seq_len[i] + cache_len = infer_state.b_ready_cache_len[i] + mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] + k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous() + k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous() + ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda") + ke = ks + torch.arange(q_len, dtype=torch.int32, device="cuda") + 1 + k_fp8_list.append(k_fp8) + k_scale_list.append(k_scale) + ks_list.append(ks) + ke_list.append(ke) + offset += q_len + + k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) + k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) + kv_fp8 = (k_fp8, k_scale) + ks = torch.cat(ks_list, dim=0) + ke = torch.cat(ke_list, dim=0) + + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + kv_fp8, + weights.squeeze(-1), + ks, + ke, + clean_logits=False, ) - return fast_topk_transform_fused( - score=logits, - lengths=infer_state.lengths, - page_table_size_1=infer_state.page_table, - cu_seqlens_q=infer_state.b1_cu_q_seq_len, - topk=self.index_topk - ) - + return self.get_topk(logits, infer_state) + + def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo): + topk_indices_list = [] + offset = 0 + + for i in range(infer_state.batch_size): + q_len = infer_state.b_q_seq_len[i] + cache_len = infer_state.b_ready_cache_len[i] + end_pos = q_len + cache_len + # Slice logits for this batch (both query and sequence dimensions) + batch_logits = logits[offset:offset + q_len, :end_pos] + topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1] + mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] + indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda") + for j in range(q_len): + indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]] + topk_indices_list.append(indices) + offset += q_len + + topk_indices_ = torch.cat(topk_indices_list, dim=0) + + return topk_indices_ + + + def get_k_float32_from_buffer(self, buffer: torch.Tensor): + k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1] + k_float32 = k_fp8.float() * k_scale + return k_float32 + @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 @@ -101,12 +150,11 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: ) == 0, "Hidden size must be a power of 2 for Hadamard transform." 
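The return statement below delegates to sgl_kernel's fused Hadamard kernel. For intuition, a dense reference of the same rotation, under the assumption that hadamard_transform(x, scale) computes x @ H_n scaled by `scale`, with H_n the Sylvester-Hadamard matrix (which is why the assert above requires a power-of-two hidden size):

import torch

def dense_hadamard_rotate(x: torch.Tensor) -> torch.Tensor:
    # Build the Sylvester-Hadamard matrix H_n (n must be a power of two) and apply
    # the orthonormal rotation x @ H_n / sqrt(n), matching scale = hidden_size ** -0.5.
    n = x.size(-1)
    H = torch.ones(1, 1, dtype=torch.float32, device=x.device)
    while H.size(0) < n:
        H = torch.cat([torch.cat([H, H], dim=-1), torch.cat([H, -H], dim=-1)], dim=-2)
    return (x.float() @ H * n ** -0.5).type_as(x)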
return hadamard_transform(x, scale=hidden_size**-0.5) - def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight): + def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) self.q_lora = None k = layer_weight.wk_proj_.mm(self.hidden_states) - self.hidden_states = None k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) @@ -122,26 +170,16 @@ def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, la k = self._rotate_activation(k) return q, k - -# TODO -def fp8_paged_mqa_logits_torch(q: torch.Tensor, kv_cache: torch.Tensor, - weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, - max_model_len: int): - batch_size, next_n, heads, dim = q.size() - num_block, block_size, _, dim = kv_cache.size() - logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) - context_lens = context_lens.tolist() - for i in range(batch_size): - context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') - weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() - for block_rk in range((context_len + block_size - 1) // block_size): - block_idx = block_tables[i][block_rk] - qx, kx = q[i], kv_cache[block_idx] - k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') - mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) - s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf')) - s = torch.relu(s) * weight_slice[..., None] - s = s.sum(dim=0) - logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) - return logits \ No newline at end of file + def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager): + # k_fp8 : [seq_len, 128] torch.fp8_e4m3 + # k_scale : [seq_len, 1] torch.float32 + # mem_index : [seq_len] torch.int32 + # buffer : [10000000, 1, 132] torch.uint8 + buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + destindex_copy_indexer_ks( + k_fp8.unsqueeze(1), # Add head dimension: [seq_len, 1, 128] + k_scale.unsqueeze(1), # Add head dimension: [seq_len, 1, 1] + mem_index, + buffer + ) + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 9f503e9bd..076d3965c 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -8,7 +8,7 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from 
lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd @@ -30,7 +30,7 @@ def __init__(self, layer_num, network_config, mode=[]): def _get_qkv( self, input: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionInferStateInfo, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) @@ -68,6 +68,7 @@ def _get_qkv( @override def _bind_attention(self): + super()._bind_attention() self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) pass @@ -76,7 +77,7 @@ def _context_attention_flashmla_kernel_with_indexer( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashAttentionInferStateInfo, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: @@ -87,18 +88,19 @@ def _context_attention_flashmla_kernel_with_indexer( topk_indices = self.indexer.get_indices( infer_state, layer_weight.indexer_layer_weight, - ) + ).unsqueeze(1) + mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=topk_indices.unsqueeze(1), + indices=topk_indices, sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) return mla_out def _token_attention_flashmla_kernel_with_indexer( - self, q, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) @@ -125,3 +127,4 @@ def _token_attention_flashmla_kernel_with_indexer( return_softmax_lse=False, num_splits=0, # TODO enable_deterministic_inference ) + return o \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index f2613aacc..a70c76273 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,70 +1,22 @@ +from typing import List from typing_extensions import override import torch -from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.mem_manager import MemoryManager from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager -from lightllm.utils.log_utils import init_logger +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.distributed.pynccl import PyNcclCommunicator -logger = init_logger(__name__) - -class Deepseek3_2IndexerPagedMemoryManager: - def __init__(self, page_size): - self.page_size = page_size - return - - def set_size(self, size): - self.physics_size = size - self.num_pages = size // self.page_size - return - - def _init_buffers(self): - self.k_cache_buffer = torch.empty( - (self.page_size, 128), dtype=torch.float8_e4m3fn, device="cuda") - self.k_scale_buffer = torch.empty( - (self.page_size, 1), dtype=torch.float64, device="cuda") - return - - def alloc_paged_index(self, last_index: int, need_size): - pass - - def 
get_cell_size(self): - # Use for deepseek v3.2 exp only, 128 for k_cache(128 torch.float8_e4m3fn), 4 for scale(1 torch.float64) - return 128 + 4 - - class Deepseek3_2MemoryManager(Deepseek2MemoryManager): - def __init__( - self, - size, - dtype, - head_num, - head_dim, - layer_num, - always_copy=False, - mem_fraction=0.9, - page_size=64 - ): - self.page_size = page_size - self.indexer_paged_mem_manager = Deepseek3_2IndexerPagedMemoryManager(page_size) - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - self.indexer_paged_mem_manager.set_size(self.size) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9 ,is_sub_mem_manager=False): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) + self.indexer_ks_mem_manager = Deepseek2MemoryManager(self.size, torch.uint8, 1, 132, layer_num, is_sub_mem_manager=True) return @override def get_cell_size(self): - return super().get_cell_size() + self.indexer_paged_mem_manager.get_cell_size() - - @override - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - super()._init_buffers(size, dtype, head_num, head_dim, layer_num) - self.indexer_paged_mem_manager._init_buffers() - return - - @override - def profile_size(self, mem_fraction): - super().profile_size(mem_fraction) - if self.size % self.page_size != 0: - size_paged = (self.size // self.page_size + 1) * self.page_size - logger.warning(f"size {self.size} is not divisible by page_size {self.page_size}, will use paged_size {size_paged}") - self.size = size_paged - return \ No newline at end of file + return super().get_cell_size() + 132 + +class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 5b3fc1f13..c4e56c3c1 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -3,9 +3,8 @@ from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashInferStateInfo - +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class @@ -15,9 +14,13 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer # infer state class - infer_state_class = Deepseek3_2FlashInferStateInfo + infer_state_class = Deepseek3_2FlashAttentionStateInfo def _init_mem_manager(self): + manager_class = Deepseek3_2MemoryManager + if "triton_fp8kv" in self.mode: + manager_class = Deepseek3_2FP8KVMemoryManager + # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 if get_env_start_args().mtp_mode == 
"deepseekv3_eagle": @@ -25,7 +28,7 @@ def _init_mem_manager(self): elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": added_mtp_layer_num += get_env_start_args().mtp_step - self.mem_manager = Deepseek3_2MemoryManager( + self.mem_manager = manager_class( self.max_total_token_num, dtype=self.data_type, head_num=1, diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py new file mode 100644 index 000000000..a098795fb --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -0,0 +1,137 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_indexer_ks( + k_fp8, + k_scale, + mem_index, + buffer_fp8, + buffer_scale, + stride_k_fp8_bs, + stride_k_fp8_h, + stride_k_fp8_d, + stride_k_scale_bs, + stride_k_scale_h, + stride_k_scale_d, + stride_buffer_fp8_bs, + stride_buffer_fp8_h, + stride_buffer_fp8_d, + stride_buffer_scale_bs, + stride_buffer_scale_h, + stride_buffer_scale_d, + head_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(mem_index + cur_index).to(tl.int64) + + # Load k_fp8 data + k_fp8_ptrs = k_fp8 + cur_index * stride_k_fp8_bs + stride_k_fp8_h * offs_h[:, None] + stride_k_fp8_d * offs_d[None, :] + k_fp8_data = tl.load(k_fp8_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + + # Load k_scale data + k_scale_ptrs = k_scale + cur_index * stride_k_scale_bs + stride_k_scale_h * offs_h[:, None] + stride_k_scale_d * tl.arange(0, 1)[None, :] + k_scale_data = tl.load(k_scale_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + + # Store k_fp8 to buffer_fp8 + buffer_fp8_ptrs = buffer_fp8 + dest_index * stride_buffer_fp8_bs + stride_buffer_fp8_h * offs_h[:, None] + stride_buffer_fp8_d * offs_d[None, :] + tl.store(buffer_fp8_ptrs, k_fp8_data, mask=offs_h[:, None] < head_num) + + # Store k_scale to buffer_scale + buffer_scale_ptrs = buffer_scale + dest_index * stride_buffer_scale_bs + stride_buffer_scale_h * offs_h[:, None] + stride_buffer_scale_d * tl.arange(0, 1)[None, :] + tl.store(buffer_scale_ptrs, k_scale_data, mask=offs_h[:, None] < head_num) + + +@torch.no_grad() +def destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer): + seq_len = mem_index.shape[0] + head_num = k_fp8.shape[1] + k_fp8_dim = k_fp8.shape[2] # Should be 128 for float8 + k_scale_dim = k_scale.shape[2] # Should be 1 + + assert k_fp8.shape[1] == k_scale.shape[1] + assert k_fp8_dim == 128, f"k_fp8 dim should be 128, got {k_fp8_dim}" + assert k_scale_dim == 1, f"k_scale dim should be 1, got {k_scale_dim}" + assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" # 128 + 4 bytes + + # Reinterpret buffer as the appropriate types for storing + buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] + + BLOCK_HEAD = triton.next_power_of_2(head_num) + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_indexer_ks[grid]( + k_fp8, + k_scale, + mem_index, + buffer_fp8, + buffer_scale, + k_fp8.stride(0), + k_fp8.stride(1), + k_fp8.stride(2), + k_scale.stride(0), + k_scale.stride(1), + k_scale.stride(2), + buffer_fp8.stride(0), + buffer_fp8.stride(1), + buffer_fp8.stride(2), + buffer_scale.stride(0), + buffer_scale.stride(1), + buffer_scale.stride(2), + head_num, + 
BLOCK_DMODEL=k_fp8_dim, + BLOCK_HEAD=BLOCK_HEAD, + num_warps=num_warps, + num_stages=1, + ) + return + + +def test(): + import torch.nn.functional as F + + # Test parameters similar to the usage in nsa_indexer_layer_inder.py + B, N_CTX, H, K_DIM = 4, 1024, 8, 128 # batch_size, seq_len, heads, k_dim + seq_len = 50 # number of tokens to copy + dtype_fp8 = torch.float8_e4m3fn + dtype_scale = torch.float32 + + # Create test data + k_fp8 = torch.randn((seq_len, H, K_DIM), dtype=dtype_fp8).cuda() + k_scale = torch.randn((seq_len, H, 1), dtype=dtype_scale).cuda() + mem_index = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() + + # Create buffer [total_tokens, heads, 132] + buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() + + # Call the function + destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer) + + # Verify results + for i in range(seq_len): + dest_idx = mem_index[i].item() + # Check k_fp8 part + stored_fp8 = buffer[dest_idx, :, :128].view(dtype_fp8) + expected_fp8 = k_fp8[i] + assert torch.allclose(stored_fp8, expected_fp8, atol=1e-6), f"FP8 mismatch at index {i}" + + # Check k_scale part + stored_scale = buffer[dest_idx, :, 128:].view(dtype_scale)[:, :1] + expected_scale = k_scale[i] + assert torch.allclose(stored_scale, expected_scale, atol=1e-6), f"Scale mismatch at index {i}" + + print("All tests passed!") + + +if __name__ == "__main__": + test() diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py new file mode 100644 index 000000000..e69de29bb From b9f23de342f6264205e7fa45d13a05030b32ca31 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 7 Nov 2025 09:16:40 +0000 Subject: [PATCH 04/12] fix --- lightllm/models/deepseek3_2/infer_struct.py | 2 + .../layer_infer/nsa_indexer_layer_inder.py | 17 +++---- .../layer_infer/transformer_layer_infer.py | 50 ++++++++----------- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index bfdb53fd6..4d77b5f6f 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -9,6 +9,8 @@ def __init__(self): self.page_table_size_1 = None self.ks = None self.ke = None + + self.topk_indices = None return def init_some_extra_state(self, model, input_ids: torch.Tensor): diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 1977c211e..3e5e1c266 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -59,17 +59,16 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T cost = mask.sum() return logits, cost - def get_indices(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: - assert self.hidden_states is not None - assert self.q_lora is not None + def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: - q, k = self._get_q_k_bf16(infer_state, layer_weight) + q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) self._copy_ks_to_mem_cache(k_fp8, k_scale, 
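# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): a plain-PyTorch
# approximation of the (values, scales) contract that act_quant(x, block_size,
# scale_fmt) is assumed to provide -- FP8 payloads plus one float32 scale per
# block_size-wide group of the last dimension. The real Triton kernel and its
# scale_fmt handling may differ; act_quant_ref below is only meant to make the
# shapes concrete and is not part of the implementation.
import torch

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    assert x.shape[-1] % block_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    groups = x.float().view(*x.shape[:-1], x.shape[-1] // block_size, block_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(x.shape), scale.squeeze(-1)

k = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
k_fp8_ref, k_scale_ref = act_quant_ref(k)   # [4, 128] fp8 values, [4, 1] float32 scales
k_back = k_fp8_ref.float() * k_scale_ref    # dequantize; close to k within FP8 precision
# ---------------------------------------------------------------------------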
infer_state.mem_index, infer_state.mem_manager) - weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale + weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] @@ -150,11 +149,11 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: ) == 0, "Hidden size must be a power of 2 for Hadamard transform." return hadamard_transform(x, scale=hidden_size**-0.5) - def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): - q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) - self.q_lora = None + def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): + q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) - k = layer_weight.wk_proj_.mm(self.hidden_states) + k = layer_weight.wk_proj_.mm(hidden_states) k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 076d3965c..01514e96a 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,5 +1,6 @@ from functools import partial from typing import override +from venv import logger import torch from sgl_kernel.flash_mla import flash_mla_sparse_fwd @@ -24,6 +25,7 @@ def __init__(self, layer_num, network_config, mode=[]): network_config=self.network_config_, mode=mode ) + self.topk_indices = None return @override @@ -35,20 +37,15 @@ def _get_qkv( ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - if self.q_lora_rank is None: - q = layer_weight.q_weight_.mm(input) - cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) - else: - q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 - ) - q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - self.indexer.hidden_states = input - self.indexer.q_lora = q + self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) - q = layer_weight.q_b_proj_.mm(q) - cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) rmsnorm_forward( @@ -69,11 +66,11 @@ def _get_qkv( @override def _bind_attention(self): super()._bind_attention() - self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) - self._token_attention_kernel = 
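# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): a slow reference for what
# _rotate_activation's hadamard_transform(x, scale=hidden_size ** -0.5) is assumed
# to compute -- an orthonormal Walsh-Hadamard rotation of the last dimension in
# Sylvester ordering. The fast_hadamard_transform kernel is the real
# implementation; hadamard_ref only documents the math and is handy in tests.
import torch

def hadamard_ref(x: torch.Tensor) -> torch.Tensor:
    n = x.shape[-1]
    assert n & (n - 1) == 0, "last dim must be a power of 2"
    y = x.float().reshape(-1, n)
    h = 1
    while h < n:
        # one butterfly stage: split each 2h-wide block into halves and take sum/diff
        y = y.view(-1, n // (2 * h), 2, h)
        y = torch.stack((y[:, :, 0, :] + y[:, :, 1, :], y[:, :, 0, :] - y[:, :, 1, :]), dim=2)
        y = y.reshape(-1, n)
        h *= 2
    return (y * n ** -0.5).view(x.shape).to(x.dtype)

x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
x_rot = hadamard_ref(x)
# The rotation is its own inverse: applying it twice recovers x up to bf16 rounding.
x_back = hadamard_ref(x_rot)
# ---------------------------------------------------------------------------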
partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) + self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) pass - def _context_attention_flashmla_kernel_with_indexer( + def _nsa_context_attention_kernel( self, q: torch.Tensor, kv, @@ -85,21 +82,17 @@ def _context_attention_flashmla_kernel_with_indexer( q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - topk_indices = self.indexer.get_indices( - infer_state, - layer_weight.indexer_layer_weight, - ).unsqueeze(1) mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=topk_indices, + indices=self.topk_indices, sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) return mla_out - def _token_attention_flashmla_kernel_with_indexer( + def _nsa_token_attention_kernel( self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] @@ -107,24 +100,23 @@ def _token_attention_flashmla_kernel_with_indexer( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - topk_indices = self.indexer.get_indices( - infer_state, - layer_weight.indexer_layer_weight, - ) - o = flash_attn_with_kvcache( + k_descale, v_descale = None, None + o_tensor = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope, v_cache=kv_nope, qv=q_nope, - page_table=topk_indices, + page_table=self.topk_indices, cache_seqlens=infer_state.b_att_seq_len, cu_seqlens_q=infer_state.cu_seqlens_q, cu_seqlens_k_new=infer_state.cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, + window_size=(-1, -1), softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, return_softmax_lse=False, - num_splits=0, # TODO enable_deterministic_inference ) - return o \ No newline at end of file + return o_tensor \ No newline at end of file From ee7100c176efffe3c5040ab2d994fc41b63488c2 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 7 Nov 2025 10:09:56 +0000 Subject: [PATCH 05/12] need fix --- lightllm/models/deepseek3_2/__init__.py | 0 lightllm/models/deepseek3_2/infer_struct.py | 10 ++++++++-- .../layer_infer/transformer_layer_infer.py | 12 +++--------- lightllm/models/deepseek3_2/model.py | 5 +++++ 4 files changed, 16 insertions(+), 11 deletions(-) create mode 100644 lightllm/models/deepseek3_2/__init__.py diff --git a/lightllm/models/deepseek3_2/__init__.py b/lightllm/models/deepseek3_2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 4d77b5f6f..b1e61413c 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -9,8 +9,8 @@ def __init__(self): self.page_table_size_1 = None self.ks = None self.ke = None - - self.topk_indices = None + self.nsa_cu_seqlens_k = None + self.index_topk = 2048 return def init_some_extra_state(self, model, input_ids: torch.Tensor): @@ -24,3 +24,9 @@ 
def init_some_extra_state(self, model, input_ids: torch.Tensor): # since b_q_seq_len represents the new tokens being processed if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len + + self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk) + assert self.nsa_cache_seqlens.dtype == torch.int32 + self.nsa_cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) + ) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 01514e96a..188ab8b4a 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -86,7 +86,7 @@ def _nsa_context_attention_kernel( mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=self.topk_indices, + indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) @@ -100,23 +100,17 @@ def _nsa_token_attention_kernel( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - k_descale, v_descale = None, None o_tensor = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope, v_cache=kv_nope, qv=q_nope, page_table=self.topk_indices, - cache_seqlens=infer_state.b_att_seq_len, + cache_seqlens=infer_state.nsa_cache_seqlens, cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, ) return o_tensor \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index c4e56c3c1..ad7f70550 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -16,6 +16,11 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # infer state class infer_state_class = Deepseek3_2FlashAttentionStateInfo + def __init__(self, kvargs): + super().__init__(kvargs) + self.index_topk = self.config["index_topk"] + return + def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager if "triton_fp8kv" in self.mode: From 425edb27e8be8e55992d505939090aeb3f8dc194 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:07:54 +0000 Subject: [PATCH 06/12] run like deepseek v3 --- lightllm/models/deepseek2/model.py | 2 +- lightllm/models/deepseek3_2/infer_struct.py | 43 +++++- .../layer_infer/nsa_indexer_layer_inder.py | 104 ++++--------- .../layer_infer/transformer_layer_infer.py | 22 +-- lightllm/models/deepseek3_2/model.py | 5 +- .../triton_kernel/fp8_mqa_logits.py | 139 ++++++++++++++++++ 6 files changed, 225 insertions(+), 90 deletions(-) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index a08147769..ea02dcd7f 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -50,7 +50,7 @@ def __init__(self, model): self.softmax_scale = self.softmax_scale * mscale * mscale -@ModelRegistry(["deepseek_v2", "deepseek_v3"]) +@ModelRegistry(["deepseek_v2", "deepseek_v3", "deepseek_v32"]) class 
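# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): a toy worked example of
# the nsa_cache_seqlens / nsa_cu_seqlens_k construction above. Per-request KV
# lengths are clamped to index_topk (the sparse-attention budget) and then turned
# into the usual padded cumulative-sum layout. The numbers are invented.
import torch

index_topk = 2048
b_att_seq_len = torch.tensor([5, 3000, 40], dtype=torch.int32, device="cuda")

nsa_cache_seqlens = b_att_seq_len.clamp(max=index_topk)   # [5, 2048, 40]
nsa_cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)                                                         # [0, 5, 2053, 2093]
# ---------------------------------------------------------------------------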
Deepseek2TpPartModel(LlamaTpPartModel): # weight class transformer_weight_class = Deepseek2TransformerLayerWeight diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index b1e61413c..8e5eb0b81 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,5 +1,6 @@ import torch from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): @@ -15,6 +16,9 @@ def __init__(self): def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) + self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager + # Ensure b_ready_cache_len is set for both prefill and decode modes if self.is_prefill: # b_ready_cache_len is already set in basemodel.py for prefill @@ -24,9 +28,42 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): # since b_q_seq_len represents the new tokens being processed if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len - - self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk) + + self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk) assert self.nsa_cache_seqlens.dtype == torch.int32 self.nsa_cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) - ) \ No newline at end of file + ) + + # Pre-compute NSA indexer indexing structures + self._init_nsa_indexing_structures() + + def _init_nsa_indexing_structures(self): + """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" + mem_index_list = [] + ks_list = [] + ke_list = [] + lengths_list = [] + offset = 0 + num_seq_len = self.b_req_idx.shape[0] + self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda') + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i] + q_seq_len = self.b_q_seq_len[i] + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + mem_index_list.append(mem_index) + self.page_table_size_1[i, :seq_len] = mem_index + ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks_list.append(ks) + ke_list.append(ke) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + offset += seq_len + + self.mem_index = torch.cat(mem_index_list, dim=0) + # ks : [seq_len_q] 标志kv的起始位置 + # ke : [seq_len_q] 标志kv的结束位置 + self.ks = torch.cat(ks_list, dim=0) + self.ke = torch.cat(ke_list, dim=0) + self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 3e5e1c266..d7444e918 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -10,7 +10,9 @@ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks -# from 
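# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): what
# _init_nsa_indexing_structures above produces for a toy batch of two prefill
# requests. Per the in-code comments, ks/ke mark the start and end of the KV range
# each query token may attend to in the concatenated per-request KV layout, and
# lengths holds the per-token visible-KV counts used by the top-k transform. The
# code below just replays the same formulas on invented sizes.
import torch

b_seq_len = torch.tensor([4, 3], device="cuda")    # total tokens per request
b_q_seq_len = torch.tensor([4, 3], device="cuda")  # new tokens per request (prefill)

ks_list, ke_list, lengths_list, offset = [], [], [], 0
for seq_len, q_seq_len in zip(b_seq_len.tolist(), b_q_seq_len.tolist()):
    ks_list.append(torch.zeros(q_seq_len, dtype=torch.int32, device="cuda") + offset)
    ke_list.append(torch.arange(q_seq_len, dtype=torch.int32, device="cuda") + offset + 1)
    lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int32, device="cuda"))
    offset += seq_len

ks = torch.cat(ks_list)            # [0, 0, 0, 0, 4, 4, 4]
ke = torch.cat(ke_list)            # [1, 2, 3, 4, 5, 6, 7]
lengths = torch.cat(lengths_list)  # [1, 2, 3, 4, 1, 2, 3]
# ---------------------------------------------------------------------------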
lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): @@ -66,70 +68,37 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) - self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager) + destindex_copy_indexer_ks( + k_fp8.unsqueeze(1), + k_scale.unsqueeze(1), + infer_state.mem_index, + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] - - k_fp8_list = [] - k_scale_list = [] - ks_list = [] - ke_list = [] - offset = 0 - for i in range(infer_state.batch_size): - q_len = infer_state.b_q_seq_len[i] - cache_len = infer_state.b_ready_cache_len[i] - mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] - k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous() - k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous() - ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda") - ke = ks + torch.arange(q_len, dtype=torch.int32, device="cuda") + 1 - k_fp8_list.append(k_fp8) - k_scale_list.append(k_scale) - ks_list.append(ks) - ke_list.append(ke) - offset += q_len - - k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) - k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) - kv_fp8 = (k_fp8, k_scale) - ks = torch.cat(ks_list, dim=0) - ke = torch.cat(ke_list, dim=0) - - logits = deep_gemm.fp8_mqa_logits( - q_fp8, - kv_fp8, - weights.squeeze(-1), - ks, - ke, - clean_logits=False, - ) - - return self.get_topk(logits, infer_state) - - def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo): - topk_indices_list = [] - offset = 0 - - for i in range(infer_state.batch_size): - q_len = infer_state.b_q_seq_len[i] - cache_len = infer_state.b_ready_cache_len[i] - end_pos = q_len + cache_len - # Slice logits for this batch (both query and sequence dimensions) - batch_logits = logits[offset:offset + q_len, :end_pos] - topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1] - mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] - indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda") - for j in range(q_len): - indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]] - topk_indices_list.append(indices) - offset += q_len + # Use pre-computed indexing structures from infer_state + mem_index = infer_state.mem_index + ks = infer_state.ks + ke = infer_state.ke + lengths = infer_state.lengths + page_table_1 = infer_state.page_table_size_1 - topk_indices_ = torch.cat(topk_indices_list, dim=0) + # TODO + k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous() + k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous() - return topk_indices_ + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), 
weights.squeeze(-1), ks, ke) + + # 返回 : [seq_q_len, topk] 无效的位置使用-1填充 + return fast_topk_transform_fused( + score=logits, # [seq_len_q, seq_len_kv] + lengths=lengths, # [seq_len_q] + page_table_size_1=page_table_1, # [seq_len_q, max(lengths)] 无效的使用0填充 + cu_seqlens_q=infer_state.cu_seqlens_q, # [seq_len_q + 1] + topk=self.index_topk, + ) def get_k_float32_from_buffer(self, buffer: torch.Tensor): @@ -152,8 +121,9 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) - k = layer_weight.wk_proj_.mm(hidden_states) + + # TODO k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) @@ -168,17 +138,3 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = self._rotate_activation(q) k = self._rotate_activation(k) return q, k - - def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager): - # k_fp8 : [seq_len, 128] torch.fp8_e4m3 - # k_scale : [seq_len, 1] torch.float32 - # mem_index : [seq_len] torch.int32 - # buffer : [10000000, 1, 132] torch.uint8 - buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] - destindex_copy_indexer_ks( - k_fp8.unsqueeze(1), # Add head dimension: [seq_len, 1, 128] - k_scale.unsqueeze(1), # Add head dimension: [seq_len, 1, 1] - mem_index, - buffer - ) - return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 188ab8b4a..ed351312f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -82,10 +82,9 @@ def _nsa_context_attention_kernel( q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], + q=q_all, # [seq_len_q, q_num_head, qk_dim] + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], # [size, 1, qk_dim] indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, @@ -100,15 +99,16 @@ def _nsa_token_attention_kernel( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=self.topk_indices, - cache_seqlens=infer_state.nsa_cache_seqlens, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, + q=q_rope, # (q_seqlen, nheads, qk_headdim) + k_cache=k_rope, # (kv_size, 1, 1, qk_head_dim) + v_cache=kv_nope, # (kv_size, 1, 1, kv_lora_rank) + qv=q_nope, # (q_seqlen, nheads, kv_lora_rank) + page_table=self.topk_indices, # (q_seqlen, max_seq_len) + cache_seqlens=infer_state.nsa_cache_seqlens, # (q_seqlen) # 表示当前kv长度,用于读取page_table. 
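# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): a condensed pure-PyTorch
# version of the top-k selection that fast_topk_transform_fused replaces (compare
# the removed get_topk above). For each query it keeps the index_topk
# highest-scoring visible KV positions, translates them into memory-slot indices,
# and pads unused entries with -1. Names and the per-query slot_table layout are
# illustrative; the fused sgl_kernel op is the authoritative implementation.
import torch

def topk_indices_ref(logits, visible_lens, slot_table, index_topk):
    # logits:       [num_q, num_kv]  scores per (query, concatenated-KV position)
    # visible_lens: [num_q]          how many leading KV positions each query may use
    # slot_table:   [num_q, num_kv]  memory slot for each KV position seen by that query
    num_q = logits.shape[0]
    out = torch.full((num_q, index_topk), -1, dtype=torch.int32, device=logits.device)
    for i in range(num_q):
        n = int(visible_lens[i])
        k = min(index_topk, n)
        top_pos = logits[i, :n].topk(k).indices
        out[i, :k] = slot_table[i, top_pos].to(torch.int32)
    return out
# ---------------------------------------------------------------------------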
+ cu_seqlens_q=infer_state.cu_seqlens_q, # (batch_size+1) [0,1] + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, #(batch_size+1) [0,9] max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index ad7f70550..b80094488 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,7 +5,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager -@ModelRegistry(["deepseek_v32"]) +# @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight @@ -21,6 +21,9 @@ def __init__(self, kvargs): self.index_topk = self.config["index_topk"] return + def _init_inferstate_cls(self): + self.infer_state_class = Deepseek3_2FlashAttentionStateInfo + def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager if "triton_fp8kv" in self.mode: diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py index e69de29bb..2fc92662a 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -0,0 +1,139 @@ +import triton +import triton.language as tl +import torch + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + Q_ptr, KV_ptr, KVScale_ptr, Weights_ptr, MemIndex_ptr, + CuSeqlenKs_ptr, CuSeqlenKe_ptr, Output_ptr, + seq_len, seq_len_kv, num_heads, head_dim, + stride_q_seq, stride_q_head, stride_q_dim, + stride_kv_pool, stride_kv_dim, + stride_w_seq, stride_w_head, + stride_o_seq, stride_o_kv, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Compute the range of seq positions this block handles + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + # Offset arrays for this block + offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) + + # Initialize accumulator for logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Create masks + mask_m = offs_m < seq_len + mask_n = offs_n < seq_len_kv + + # Load mem_indices for the KV positions + mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) + + # Load scales for K + scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) + + # Loop over all heads + for h in range(num_heads): + # Load weights for this head + weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, + mask=mask_m, other=0.0) + + # Initialize score accumulator for this head + score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Loop over head_dim in blocks + for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): + d_start = d_block * BLOCK_SIZE_D + offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) + mask_d = offs_d < head_dim + + # Load Q for this head and dimension block + # Q shape: (seq_len, num_heads, head_dim) + q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim + mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] + q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) + + # Load K for this 
dimension block + # KV shape: (pool_size, head_dim) as FP8 data + k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim + mask_k = mask_n[:, None] & mask_d[None, :] + k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) + + # Apply scale to K (scale is per-row of K) + k = k * scales[:, None] + + # Compute partial dot product: q @ k.T + # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) + # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) + score += tl.dot(q, tl.trans(k)) + + # Apply ReLU to score + score = tl.maximum(score, 0.0) + + # Multiply by weights and accumulate to logits + logits += score * weights[:, None] + + # Apply mask based on cu_seqlen_ks and cu_seqlen_ke + mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) + mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) + + mask_lo = offs_n[None, :] >= mask_ks[:, None] + mask_hi = offs_n[None, :] < mask_ke[:, None] + mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] + + # Apply mask (-inf for masked positions) + logits = tl.where(mask_valid, logits, float('-inf')) + + # Store output + out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv + mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) + tl.store(out_ptrs, logits, mask=mask_out) + + +def fp8_paged_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor = None +) -> torch.Tensor: + seq_len, num_heads, head_dim = q.shape + seq_len_kv = mem_index.shape[0] + + if out is None: + output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) + else: + output = out + + BLOCK_SIZE_M = 16 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_D = 128 + + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) + + _fp8_paged_mqa_logits_kernel[grid]( + q, kv, kv_scale, weights, mem_index, + cu_seqlen_ks, cu_seqlen_ke, output, + seq_len, seq_len_kv, num_heads, head_dim, + q.stride(0), q.stride(1), q.stride(2), + kv.stride(0), kv.stride(1), + weights.stride(0), weights.stride(1), + output.stride(0), output.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + + return output \ No newline at end of file From d1a773d22cbd21b915fef6b1864d59e0ba9ba37a Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:41:25 +0000 Subject: [PATCH 07/12] fix --- .../layer_infer/nsa_indexer_layer_inder.py | 12 +- .../triton_kernel/extract_indexer_ks.py | 156 ++++++++++++++++++ 2 files changed, 160 insertions(+), 8 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index d7444e918..173196bf4 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -10,6 +10,8 @@ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks +from 
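# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch): a slow PyTorch oracle for
# the Triton fp8_paged_mqa_logits kernel above -- gather and dequantize the selected
# pool rows, take ReLU'd per-head dot products, combine them with the per-query head
# weights, and mask everything outside [ks, ke). The name ..._ref is local to this
# sketch; it is only meant for correctness checks, not for serving.
import torch

def fp8_paged_mqa_logits_ref(q, kv, kv_scale, weights, mem_index, cu_seqlen_ks, cu_seqlen_ke):
    # q: [sq, h, d]; kv: [pool, d] float8_e4m3fn; kv_scale: [pool] float32
    # weights: [sq, h]; mem_index: [skv]; cu_seqlen_ks / cu_seqlen_ke: [sq]
    idx = mem_index.long()
    # gather through a uint8 view so the sketch does not rely on fancy-indexing support for float8
    k = kv.view(torch.uint8)[idx].view(torch.float8_e4m3fn).float() * kv_scale[idx].unsqueeze(-1)
    score = torch.einsum("mhd,nd->hmn", q.float(), k).relu()           # per-head scores
    logits = (score * weights.float().t().unsqueeze(-1)).sum(dim=0)    # weighted sum over heads, [sq, skv]
    pos = torch.arange(k.shape[0], device=logits.device)
    visible = (pos[None, :] >= cu_seqlen_ks[:, None]) & (pos[None, :] < cu_seqlen_ke[:, None])
    return logits.masked_fill(~visible, float("-inf"))
# ---------------------------------------------------------------------------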
lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -78,16 +80,13 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - # Use pre-computed indexing structures from infer_state mem_index = infer_state.mem_index ks = infer_state.ks ke = infer_state.ke lengths = infer_state.lengths page_table_1 = infer_state.page_table_size_1 - # TODO - k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous() - k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous() + k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index) logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) @@ -123,10 +122,7 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - # TODO - k = F.layer_norm( - k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps - ).type_as(k) + k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps) rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py new file mode 100644 index 000000000..e97454ba2 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -0,0 +1,156 @@ +import torch +import triton +import triton.language as tl +import numpy + + +@triton.jit +def _fwd_kernel_extract_indexer_ks( + buffer_fp8, + buffer_scale, + mem_index, + k_fp8_out, + k_scale_out, + stride_buffer_fp8_bs, + stride_buffer_fp8_h, + stride_buffer_fp8_d, + stride_buffer_scale_bs, + stride_buffer_scale_h, + stride_buffer_scale_d, + stride_k_fp8_out_bs, + stride_k_fp8_out_d, + stride_k_scale_out_bs, + BLOCK_DMODEL: tl.constexpr, +): + cur_index = tl.program_id(0) + + # Load the memory index + mem_idx = tl.load(mem_index + cur_index).to(tl.int64) + + # Load k_fp8 data from buffer_fp8[mem_idx, 0, :] + offs_d = tl.arange(0, BLOCK_DMODEL) + k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d + k_fp8_data = tl.load(k_fp8_ptrs) + + # Load k_scale data from buffer_scale[mem_idx, 0, 0] + k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d + k_scale_data = tl.load(k_scale_ptr) + + # Store k_fp8 output + k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d + tl.store(k_fp8_out_ptrs, k_fp8_data) + + # Store k_scale output + k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs + tl.store(k_scale_out_ptr, k_scale_data) + + +@torch.no_grad() +def extract_indexer_ks(buffer, mem_index): + """ + Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel. 
+ + Args: + buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8 + mem_index: Indices tensor of shape [seq_len] with dtype int32/int64 + + Returns: + k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn + k_scale: Tensor of shape [seq_len] with dtype float32 + """ + seq_len = mem_index.shape[0] + assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" + + # Reinterpret buffer as the appropriate types for Triton + buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] + + # Prepare output tensors + k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device) + k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device) + + BLOCK_DMODEL = 128 + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_extract_indexer_ks[grid]( + buffer_fp8, + buffer_scale, + mem_index, + k_fp8_out, + k_scale_out, + buffer_fp8.stride(0), + buffer_fp8.stride(1), + buffer_fp8.stride(2), + buffer_scale.stride(0), + buffer_scale.stride(1), + buffer_scale.stride(2), + k_fp8_out.stride(0), + k_fp8_out.stride(1), + k_scale_out.stride(0), + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + + return k_fp8_out, k_scale_out + + +def test(): + # Test parameters similar to the usage in nsa_indexer_layer_inder.py + B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this) + seq_len = 50 # number of tokens to extract + dtype_fp8 = torch.float8_e4m3fn + dtype_scale = torch.float32 + + # Create test buffer [total_tokens, heads, 132] as uint8 + buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() + + # Fill with test data - simulate what destindex_copy_indexer_ks does + test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() + # Generate fp8 data by converting from float32 + test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda() + test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8) + test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda() + + # Manually populate buffer as destindex_copy_indexer_ks would + for i in range(seq_len): + dest_idx = test_indices[i].item() + # Store fp8 data + buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8) + # Store scale data (4 bytes) - need to convert float32 to bytes + scale_bytes = test_k_scale[i].cpu().numpy().tobytes() + scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8) + buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device) + + # Call our extraction function + extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices) + + # Verify results + print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}") + print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: {extracted_fp8.dtype}") + print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}") + print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}") + + # Check if extraction matches (convert fp8 to float32 for comparison) + # Use higher tolerance for fp8 due to quantization precision + fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1) + scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6) + + print(f"FP8 data matches: {fp8_match}") + print(f"Scale data matches: {scale_match}") + + if fp8_match and scale_match: + print("All tests passed!") + else: + print("Test failed!") + if 
not fp8_match: + print("First few fp8 values:") + print(f"Original: {test_k_fp8_fp32[0, :5]}") + print(f"Extracted: {extracted_fp8.float()[0, :5]}") + if not scale_match: + print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}") + + +if __name__ == "__main__": + test() From eb4b957c03e6efb33544d786492936acd49eb616 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:49:38 +0000 Subject: [PATCH 08/12] fix --- lightllm/common/basemodel/basemodel.py | 2 +- lightllm/models/deepseek2/model.py | 2 +- lightllm/models/deepseek3_2/model.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5be45c6b1..77ca299b2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -110,7 +110,7 @@ def __init__(self, kvargs): self._init_some_value() self._init_custom() self._init_inferstate_cls() - # self._autotune_warmup() + self._autotune_warmup() self._init_padded_req() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index ea02dcd7f..a08147769 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -50,7 +50,7 @@ def __init__(self, model): self.softmax_scale = self.softmax_scale * mscale * mscale -@ModelRegistry(["deepseek_v2", "deepseek_v3", "deepseek_v32"]) +@ModelRegistry(["deepseek_v2", "deepseek_v3"]) class Deepseek2TpPartModel(LlamaTpPartModel): # weight class transformer_weight_class = Deepseek2TransformerLayerWeight diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index b80094488..8f1ba85cf 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,7 +5,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager -# @ModelRegistry(["deepseek_v32"]) +@ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight From 33d81f3c04d967c4f1844de162109495b8b5fc6e Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:51:11 +0000 Subject: [PATCH 09/12] fix --- .../models/deepseek3_2/layer_infer/transformer_layer_infer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index ed351312f..5fc33d5aa 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,6 +1,5 @@ from functools import partial from typing import override -from venv import logger import torch from sgl_kernel.flash_mla import flash_mla_sparse_fwd From 3b3a20465381125eac6f47b8b8e439630cdf7ae3 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:53:59 +0000 Subject: [PATCH 10/12] fix --- lightllm/models/deepseek3_2/infer_struct.py | 2 -- .../layer_infer/nsa_indexer_layer_inder.py | 9 ++++----- .../layer_infer/transformer_layer_infer.py | 20 +++++++++---------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py 
index 8e5eb0b81..e955c3bbd 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -62,8 +62,6 @@ def _init_nsa_indexing_structures(self): offset += seq_len self.mem_index = torch.cat(mem_index_list, dim=0) - # ks : [seq_len_q] 标志kv的起始位置 - # ke : [seq_len_q] 标志kv的结束位置 self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 173196bf4..d5032e72f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -90,12 +90,11 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) - # 返回 : [seq_q_len, topk] 无效的位置使用-1填充 return fast_topk_transform_fused( - score=logits, # [seq_len_q, seq_len_kv] - lengths=lengths, # [seq_len_q] - page_table_size_1=page_table_1, # [seq_len_q, max(lengths)] 无效的使用0填充 - cu_seqlens_q=infer_state.cu_seqlens_q, # [seq_len_q + 1] + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.cu_seqlens_q, topk=self.index_topk, ) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 5fc33d5aa..5b550ab09 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -82,8 +82,8 @@ def _nsa_context_attention_kernel( q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, # [seq_len_q, q_num_head, qk_dim] - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], # [size, 1, qk_dim] + q=q_all, + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, @@ -100,14 +100,14 @@ def _nsa_token_attention_kernel( kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) o_tensor = flash_attn_with_kvcache( - q=q_rope, # (q_seqlen, nheads, qk_headdim) - k_cache=k_rope, # (kv_size, 1, 1, qk_head_dim) - v_cache=kv_nope, # (kv_size, 1, 1, kv_lora_rank) - qv=q_nope, # (q_seqlen, nheads, kv_lora_rank) - page_table=self.topk_indices, # (q_seqlen, max_seq_len) - cache_seqlens=infer_state.nsa_cache_seqlens, # (q_seqlen) # 表示当前kv长度,用于读取page_table. 
- cu_seqlens_q=infer_state.cu_seqlens_q, # (batch_size+1) [0,1] - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, #(batch_size+1) [0,9] + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=self.topk_indices, + cache_seqlens=infer_state.nsa_cache_seqlens, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, From 2ba2189d9eeebe6bfac7ceeb83c98b3f213ef45f Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 13:57:17 +0000 Subject: [PATCH 11/12] can run without cudagraph --- .../meta_weights/fused_moe_weight_ep.py | 3 - lightllm/models/deepseek3_2/infer_struct.py | 6 +- .../layer_infer/nsa_indexer_layer_inder.py | 24 +- .../layer_infer/transformer_layer_infer.py | 6 +- .../destindex_copy_indexer_ks.py | 354 ++++++++++----- .../triton_kernel/extract_indexer_ks.py | 409 ++++++++++++------ 6 files changed, 547 insertions(+), 255 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py index 87a7b361e..d64b17bf9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py @@ -67,9 +67,6 @@ def __init__( self.global_rank_ = get_global_rank() self.redundancy_expert_num = get_redundancy_expert_num() self.redundancy_expert_ids = get_redundancy_expert_ids(layer_num) - logger.info( - f"global_rank {self.global_rank_} layerindex {layer_num} redundancy_expertids: {self.redundancy_expert_ids}" - ) self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda") self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda") self.total_expert_num_contain_redundancy = ( diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index e955c3bbd..c122c6a7e 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -40,7 +40,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): def _init_nsa_indexing_structures(self): """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" - mem_index_list = [] + req_all_mem_index_list = [] ks_list = [] ke_list = [] lengths_list = [] @@ -52,7 +52,7 @@ def _init_nsa_indexing_structures(self): seq_len = self.b_seq_len[i] q_seq_len = self.b_q_seq_len[i] mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - mem_index_list.append(mem_index) + req_all_mem_index_list.append(mem_index) self.page_table_size_1[i, :seq_len] = mem_index ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 @@ -61,7 +61,7 @@ def _init_nsa_indexing_structures(self): lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) offset += seq_len - self.mem_index = torch.cat(mem_index_list, dim=0) + self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index d5032e72f..df045dd2d 
100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -71,8 +71,8 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8.unsqueeze(1), - k_scale.unsqueeze(1), + k_fp8, + k_scale, infer_state.mem_index, infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] ) @@ -80,13 +80,16 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - mem_index = infer_state.mem_index ks = infer_state.ks ke = infer_state.ke lengths = infer_state.lengths page_table_1 = infer_state.page_table_size_1 - k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index) + # Use efficient Triton kernel to extract FP8 keys and scales from buffer + k_fp8_, k_scale_ = extract_indexer_ks( + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], + infer_state.req_all_mem_index + ) logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) @@ -99,12 +102,6 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, ) - def get_k_float32_from_buffer(self, buffer: torch.Tensor): - k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1] - k_float32 = k_fp8.float() * k_scale - return k_float32 - @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 @@ -121,8 +118,11 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps) - + # TODO + k = F.layer_norm( + k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps + ).type_as(k) + rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], k[:, None, : self.qk_rope_head_dim], diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 5b550ab09..df5220427 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -64,7 +64,11 @@ def _get_qkv( @override def _bind_attention(self): - super()._bind_attention() + if "triton_fp8kv" in self.mode: + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) + else: + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) pass diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index a098795fb..46095bfb7 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -6,132 +6,270 @@ @triton.jit def 
_fwd_kernel_destindex_copy_indexer_ks( - k_fp8, - k_scale, - mem_index, - buffer_fp8, - buffer_scale, - stride_k_fp8_bs, - stride_k_fp8_h, - stride_k_fp8_d, - stride_k_scale_bs, - stride_k_scale_h, - stride_k_scale_d, - stride_buffer_fp8_bs, - stride_buffer_fp8_h, - stride_buffer_fp8_d, - stride_buffer_scale_bs, - stride_buffer_scale_h, - stride_buffer_scale_d, - head_num, + K_fp8, + K_scale, + DestLoc, + O_buffer, + stride_k_bs, + stride_k_d, + stride_scale_bs, + stride_scale_d, + stride_o_bs, + stride_o_h, + stride_o_d, BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, ): + """ + Triton kernel to copy FP8 K values and their scales to an indexed output buffer. + + This kernel reads FP8 key values (128 dims) and their float32 scale values, + then writes them to a compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The destination location for each source element is specified by DestLoc. + """ cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(mem_index + cur_index).to(tl.int64) - - # Load k_fp8 data - k_fp8_ptrs = k_fp8 + cur_index * stride_k_fp8_bs + stride_k_fp8_h * offs_h[:, None] + stride_k_fp8_d * offs_d[None, :] - k_fp8_data = tl.load(k_fp8_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - - # Load k_scale data - k_scale_ptrs = k_scale + cur_index * stride_k_scale_bs + stride_k_scale_h * offs_h[:, None] + stride_k_scale_d * tl.arange(0, 1)[None, :] - k_scale_data = tl.load(k_scale_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - - # Store k_fp8 to buffer_fp8 - buffer_fp8_ptrs = buffer_fp8 + dest_index * stride_buffer_fp8_bs + stride_buffer_fp8_h * offs_h[:, None] + stride_buffer_fp8_d * offs_d[None, :] - tl.store(buffer_fp8_ptrs, k_fp8_data, mask=offs_h[:, None] < head_num) - - # Store k_scale to buffer_scale - buffer_scale_ptrs = buffer_scale + dest_index * stride_buffer_scale_bs + stride_buffer_scale_h * offs_h[:, None] + stride_buffer_scale_d * tl.arange(0, 1)[None, :] - tl.store(buffer_scale_ptrs, k_scale_data, mask=offs_h[:, None] < head_num) + + # Load destination index for this thread + dest_index = tl.load(DestLoc + cur_index).to(tl.int64) + + # Load K_fp8 (128 values) and K_scale (1 value) from source + k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d + k_fp8 = tl.load(k_fp8_ptrs) + + k_scale = tl.load(K_scale + cur_index * stride_scale_bs) + + # Store K_fp8 to O_buffer[:, 0, :128] + # Convert fp8 to uint8 through bitcast for storage in uint8 buffer + o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d + k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True) + tl.store(o_k_ptrs, k_fp8_as_uint8) + + # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32) + # Convert float32 scale to 4 uint8 bytes using bitcast and bit manipulation + o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d + scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True) + + # Store each byte of the float32 scale (little-endian) + for i in range(4): + byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8) + tl.store(o_scale_ptr + i * stride_o_d, byte_val) + + return @torch.no_grad() -def destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer): - seq_len = mem_index.shape[0] - head_num = k_fp8.shape[1] - k_fp8_dim = k_fp8.shape[2] # Should be 128 for float8 - k_scale_dim = k_scale.shape[2] # Should be 1 +def 
destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor): + """ + Copy FP8-quantized key values and their scales to indexed locations in a buffer. + + This function is used in the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) + mechanism to store compressed key representations in a memory buffer. Each key + is stored with its FP8 representation (128 bytes) followed by its float32 scale + (4 bytes), for a total of 132 bytes per key. + + Args: + K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn + FP8-quantized key values + K_scale: [q_seq_len, 1] torch.float32 + Quantization scales for each key + DestLoc: [q_seq_len] torch.int32 + Destination indices in the output buffer + O_buffer: [large_size, 1, 132] torch.uint8 + Output buffer where keys and scales will be written. + Must be a uint8 tensor to allow mixed-type storage. + Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales - assert k_fp8.shape[1] == k_scale.shape[1] - assert k_fp8_dim == 128, f"k_fp8 dim should be 128, got {k_fp8_dim}" - assert k_scale_dim == 1, f"k_scale dim should be 1, got {k_scale_dim}" - assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" # 128 + 4 bytes - - # Reinterpret buffer as the appropriate types for storing - buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] - - BLOCK_HEAD = triton.next_power_of_2(head_num) + Returns: + None (modifies O_buffer in-place) + + Example: + >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() + >>> k_scale = torch.randn(50, 1).cuda() + >>> dest_loc = torch.randint(0, 1024, (50,), dtype=torch.int32).cuda() + >>> o_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer) + >>> # Now o_buffer[dest_loc] contains the packed k_fp8 and k_scale data + """ + seq_len = DestLoc.shape[0] + head_dim = K_fp8.shape[1] + + assert head_dim == 128, f"Expected head_dim=128, got {head_dim}" + assert K_scale.shape[0] == seq_len + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" + grid = (seq_len,) num_warps = 1 - + _fwd_kernel_destindex_copy_indexer_ks[grid]( - k_fp8, - k_scale, - mem_index, - buffer_fp8, - buffer_scale, - k_fp8.stride(0), - k_fp8.stride(1), - k_fp8.stride(2), - k_scale.stride(0), - k_scale.stride(1), - k_scale.stride(2), - buffer_fp8.stride(0), - buffer_fp8.stride(1), - buffer_fp8.stride(2), - buffer_scale.stride(0), - buffer_scale.stride(1), - buffer_scale.stride(2), - head_num, - BLOCK_DMODEL=k_fp8_dim, - BLOCK_HEAD=BLOCK_HEAD, + K_fp8, + K_scale, + DestLoc, + O_buffer, + K_fp8.stride(0), + K_fp8.stride(1), + K_scale.stride(0), + K_scale.stride(1), + O_buffer.stride(0), + O_buffer.stride(1), + O_buffer.stride(2), + BLOCK_DMODEL=head_dim, num_warps=num_warps, num_stages=1, ) return -def test(): +def test_destindex_copy_indexer_ks(): + """Test the destindex_copy_indexer_ks kernel""" import torch.nn.functional as F - - # Test parameters similar to the usage in nsa_indexer_layer_inder.py - B, N_CTX, H, K_DIM = 4, 1024, 8, 128 # batch_size, seq_len, heads, k_dim - seq_len = 50 # number of tokens to copy - dtype_fp8 = torch.float8_e4m3fn - dtype_scale = torch.float32 - - # Create test data - k_fp8 = torch.randn((seq_len, H, K_DIM), dtype=dtype_fp8).cuda() - k_scale = torch.randn((seq_len, H, 1), dtype=dtype_scale).cuda() - mem_index = torch.randint(0, B * N_CTX, (seq_len,), 
dtype=torch.int32).cuda() - - # Create buffer [total_tokens, heads, 132] - buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() - - # Call the function - destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer) - - # Verify results - for i in range(seq_len): - dest_idx = mem_index[i].item() - # Check k_fp8 part - stored_fp8 = buffer[dest_idx, :, :128].view(dtype_fp8) - expected_fp8 = k_fp8[i] - assert torch.allclose(stored_fp8, expected_fp8, atol=1e-6), f"FP8 mismatch at index {i}" - - # Check k_scale part - stored_scale = buffer[dest_idx, :, 128:].view(dtype_scale)[:, :1] - expected_scale = k_scale[i] - assert torch.allclose(stored_scale, expected_scale, atol=1e-6), f"Scale mismatch at index {i}" - - print("All tests passed!") + + print("=" * 80) + print("Testing destindex_copy_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random destination indices + dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(dest_loc) + + # Create input tensors + k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8 = (k_bf16 / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + # Create output buffer (as uint8 to allow reinterpretation) + o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + + # Run kernel + destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) + + # Extract results + k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) + + # Extract scale by reinterpreting 4 bytes as float32 + scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() + k_scale_out = scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) + + # Verify results at destination locations + k_fp8_extracted = k_fp8_out[dest_loc] + k_scale_extracted = k_scale_out[dest_loc] + + # Check FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8.to(torch.float32), + atol=0, rtol=0 + ) + + # Check scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_out = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" 
+ assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test edge cases + print("Testing edge cases...") + + # Test with sequential indices + dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) + + k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) + scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() + k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_seq = torch.allclose( + k_fp8_out_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_out_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Edge case tests passed!") + print() + + # Test with single element + print("Testing single element...") + dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) + + k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) + scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() + k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_single = torch.allclose( + k_fp8_out_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_out_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! 
✓") + print("=" * 80) if __name__ == "__main__": - test() + test_destindex_copy_indexer_ks() \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index e97454ba2..eb22fbb8f 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -1,156 +1,309 @@ import torch + import triton import triton.language as tl -import numpy @triton.jit def _fwd_kernel_extract_indexer_ks( - buffer_fp8, - buffer_scale, - mem_index, - k_fp8_out, - k_scale_out, - stride_buffer_fp8_bs, - stride_buffer_fp8_h, - stride_buffer_fp8_d, - stride_buffer_scale_bs, - stride_buffer_scale_h, - stride_buffer_scale_d, - stride_k_fp8_out_bs, - stride_k_fp8_out_d, - stride_k_scale_out_bs, + I_buffer, # Input buffer [large_size, 1, 132] uint8 + SrcLoc, # Source indices [req_size] int32/int64 + O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn + O_scale, # Output scale [req_size] float32 + stride_i_bs, + stride_i_h, + stride_i_d, + stride_o_fp8_bs, + stride_o_fp8_d, + stride_o_scale_bs, BLOCK_DMODEL: tl.constexpr, ): + """ + Triton kernel to extract FP8 K values and their scales from an indexed buffer. + + This kernel is the inverse of destindex_copy_indexer_ks. It reads from a + compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The source location for each output element is specified by SrcLoc. + """ cur_index = tl.program_id(0) - - # Load the memory index - mem_idx = tl.load(mem_index + cur_index).to(tl.int64) - - # Load k_fp8 data from buffer_fp8[mem_idx, 0, :] offs_d = tl.arange(0, BLOCK_DMODEL) - k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d - k_fp8_data = tl.load(k_fp8_ptrs) - - # Load k_scale data from buffer_scale[mem_idx, 0, 0] - k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d - k_scale_data = tl.load(k_scale_ptr) - - # Store k_fp8 output - k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d - tl.store(k_fp8_out_ptrs, k_fp8_data) - - # Store k_scale output - k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs - tl.store(k_scale_out_ptr, k_scale_data) + + # Load source index for this thread + src_index = tl.load(SrcLoc + cur_index).to(tl.int64) + + # Load K_fp8 from I_buffer[:, 0, :128] + i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d + k_fp8_as_uint8 = tl.load(i_k_ptrs) + + # Convert uint8 to fp8 through bitcast + k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) + + # Store K_fp8 to output + o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d + tl.store(o_k_ptrs, k_fp8) + + # Load K_scale from I_buffer[:, 0, 128:132] (4 bytes for float32) + # Load 4 bytes and reconstruct float32 (little-endian) + i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d + + # Load 4 bytes individually and combine them into uint32 + byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) + byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) + byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) + byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) + + # Combine bytes into uint32 (little-endian: byte0 is LSB) + scale_as_uint32 = byte0 | (byte1 
<< 8) | (byte2 << 16) | (byte3 << 24) + + # Bitcast uint32 to float32 + k_scale = scale_as_uint32.to(tl.float32, bitcast=True) + + # Store scale to output + o_scale_ptr = O_scale + cur_index * stride_o_scale_bs + tl.store(o_scale_ptr, k_scale) + + return @torch.no_grad() -def extract_indexer_ks(buffer, mem_index): +def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ - Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel. - + Extract FP8-quantized key values and their scales from indexed locations in a buffer. + + This function is the inverse operation of destindex_copy_indexer_ks. It's used in + the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) mechanism to retrieve + compressed key representations from a memory buffer. + Args: - buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8 - mem_index: Indices tensor of shape [seq_len] with dtype int32/int64 - + I_buffer: [large_size, 1, 132] torch.uint8 + Input buffer containing packed FP8 keys and float32 scales. + Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales + SrcLoc: [req_size] torch.int32 or torch.int64 + Source indices to extract from the input buffer + Returns: - k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn - k_scale: Tensor of shape [seq_len] with dtype float32 + tuple containing: + - K_fp8: [req_size, 128] torch.float8_e4m3fn + FP8-quantized key values + - K_scale: [req_size] torch.float32 + Quantization scales for each key + + Example: + >>> i_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> src_loc = torch.tensor([10, 20, 30], dtype=torch.int32).cuda() + >>> k_fp8, k_scale = extract_indexer_ks(i_buffer, src_loc) + >>> # k_fp8.shape == [3, 128], k_scale.shape == [3] """ - seq_len = mem_index.shape[0] - assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" - - # Reinterpret buffer as the appropriate types for Triton - buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] - - # Prepare output tensors - k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device) - k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device) - - BLOCK_DMODEL = 128 - grid = (seq_len,) + req_size = SrcLoc.shape[0] + head_dim = 128 + + assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" + assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" + + # Allocate output tensors + O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) + O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) + + grid = (req_size,) num_warps = 1 - + _fwd_kernel_extract_indexer_ks[grid]( - buffer_fp8, - buffer_scale, - mem_index, - k_fp8_out, - k_scale_out, - buffer_fp8.stride(0), - buffer_fp8.stride(1), - buffer_fp8.stride(2), - buffer_scale.stride(0), - buffer_scale.stride(1), - buffer_scale.stride(2), - k_fp8_out.stride(0), - k_fp8_out.stride(1), - k_scale_out.stride(0), - BLOCK_DMODEL=BLOCK_DMODEL, + I_buffer, + SrcLoc, + O_fp8, + O_scale, + I_buffer.stride(0), + I_buffer.stride(1), + I_buffer.stride(2), + O_fp8.stride(0), + O_fp8.stride(1), + O_scale.stride(0), + BLOCK_DMODEL=head_dim, num_warps=num_warps, num_stages=1, ) + + return O_fp8, O_scale - return k_fp8_out, k_scale_out - - -def test(): - # Test parameters similar to the 
usage in nsa_indexer_layer_inder.py - B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this) - seq_len = 50 # number of tokens to extract - dtype_fp8 = torch.float8_e4m3fn - dtype_scale = torch.float32 - # Create test buffer [total_tokens, heads, 132] as uint8 - buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() - - # Fill with test data - simulate what destindex_copy_indexer_ks does - test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() - # Generate fp8 data by converting from float32 - test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda() - test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8) - test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda() - - # Manually populate buffer as destindex_copy_indexer_ks would - for i in range(seq_len): - dest_idx = test_indices[i].item() - # Store fp8 data - buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8) - # Store scale data (4 bytes) - need to convert float32 to bytes - scale_bytes = test_k_scale[i].cpu().numpy().tobytes() - scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8) - buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device) - - # Call our extraction function - extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices) - - # Verify results - print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}") - print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: {extracted_fp8.dtype}") - print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}") - print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}") - - # Check if extraction matches (convert fp8 to float32 for comparison) - # Use higher tolerance for fp8 due to quantization precision - fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1) - scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6) - - print(f"FP8 data matches: {fp8_match}") - print(f"Scale data matches: {scale_match}") - - if fp8_match and scale_match: - print("All tests passed!") - else: - print("Test failed!") - if not fp8_match: - print("First few fp8 values:") - print(f"Original: {test_k_fp8_fp32[0, :5]}") - print(f"Extracted: {extracted_fp8.float()[0, :5]}") - if not scale_match: - print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}") +def test_extract_indexer_ks(): + """Test the extract_indexer_ks kernel against the copy kernel""" + import torch.nn.functional as F + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks + + print("=" * 80) + print("Testing extract_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random indices for writing + write_indices = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(write_indices) + + # Create input tensors + k_bf16_original = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16_original.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_original = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_original = (k_bf16_original / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + # Create buffer and write data using 
destindex_copy_indexer_ks + buffer = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_original, k_scale_original, write_indices, buffer) + + # Now extract the data back using extract_indexer_ks + k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, write_indices) + + # Verify FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8_original.to(torch.float32), + atol=0, rtol=0 + ) + + # Verify scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale_original.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16_original, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" + assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test with sequential indices + print("Testing sequential indices...") + write_indices_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, write_indices_seq, buffer_seq) + k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, write_indices_seq) + + fp8_match_seq = torch.allclose( + k_fp8_ext_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_ext_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Sequential test passed!") + print() + + # Test with single element + print("Testing single element...") + write_idx_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, write_idx_single, buffer_single) + k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, write_idx_single) + + fp8_match_single = torch.allclose( + k_fp8_ext_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_ext_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") + assert 
fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + # Test with larger batch to check performance characteristics + print("Testing larger batch (performance check)...") + write_indices_large = torch.randint(0, large_size * 10, (500,), device="cuda", dtype=torch.int32).unique() + actual_large_len = len(write_indices_large) + k_bf16_large = torch.randn((actual_large_len, head_dim), dtype=dtype, device="cuda") + k_abs_max_large = k_bf16_large.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_large = (k_abs_max_large / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_large = (k_bf16_large / k_abs_max_large).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_large = torch.zeros((large_size * 10, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_large, k_scale_large, write_indices_large, buffer_large) + + # Warm up + for _ in range(3): + _ = extract_indexer_ks(buffer_large, write_indices_large) + + # Time it + torch.cuda.synchronize() + import time + start = time.time() + for _ in range(100): + k_fp8_ext_large, k_scale_ext_large = extract_indexer_ks(buffer_large, write_indices_large) + torch.cuda.synchronize() + elapsed = time.time() - start + + fp8_match_large = torch.allclose( + k_fp8_ext_large.to(torch.float32), + k_fp8_large.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_large = torch.allclose( + k_scale_ext_large, + k_scale_large.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Large batch (size={actual_large_len}): FP8={fp8_match_large}, Scale={scale_match_large}") + print(f" Average time per call: {elapsed/100*1000:.3f} ms") + assert fp8_match_large and scale_match_large + print("✓ Large batch test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! 
✓") + print("=" * 80) if __name__ == "__main__": - test() + test_extract_indexer_ks() From c0f072bda2eb057cfabb44666402a2212fc68761 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 14:32:31 +0000 Subject: [PATCH 12/12] fix cudagraph --- lightllm/models/deepseek3_2/infer_struct.py | 168 +++++++++++++++++--- 1 file changed, 147 insertions(+), 21 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index c122c6a7e..db6e61a1c 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,8 +1,10 @@ import torch +import weakref from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): + _shared_nsa_buffers = None def __init__(self): super().__init__() @@ -14,8 +16,42 @@ def __init__(self): self.index_topk = 2048 return + @classmethod + def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): + """Get or create pre-allocated buffers for CUDA graph execution""" + if cls._shared_nsa_buffers is None: + # Pre-allocate buffers for max possible sizes + max_total_q_tokens = graph_max_batch_size * max_seq_len + max_total_tokens = graph_max_batch_size * max_seq_len + + cls._shared_nsa_buffers = [ + { + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + }, + { # Second buffer for microbatch overlap if needed + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + } + ] + return cls._shared_nsa_buffers + def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) + + # Store weak reference to model for accessing graph parameters + self._model_ref = weakref.ref(model) + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager @@ -29,11 +65,34 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len - self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk) + # Check if we can use CUDA graph based on batch size and max_len constraints + use_cuda_graph_buffers = False + if (hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and 
+ self.batch_size <= model.graph_max_batch_size and + self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + # Setup nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph + if use_cuda_graph_buffers: + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.nsa_cache_seqlens = buffer['nsa_cache_seqlens'][:self.batch_size] + self.nsa_cu_seqlens_k = buffer['nsa_cu_seqlens_k'][:self.batch_size + 1] + else: + # Create new tensors dynamically + self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device='cuda') + self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device='cuda') + + # Calculate actual values + self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 - self.nsa_cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) - ) + + # Compute cumulative sum with padding + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) + self.nsa_cu_seqlens_k[0] = 0 # Pre-compute NSA indexer indexing structures self._init_nsa_indexing_structures() @@ -46,22 +105,89 @@ def _init_nsa_indexing_structures(self): lengths_list = [] offset = 0 num_seq_len = self.b_req_idx.shape[0] - self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda') + max_seq_len = self.b_seq_len.max().item() + + # Calculate total sizes needed + total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) + total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) + + # Check if we should use CUDA graph buffers + use_cuda_graph_buffers = False + if hasattr(self, '_model_ref'): + model = self._model_ref() + if (model is not None and + hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and + self.batch_size <= model.graph_max_batch_size and + self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + if use_cuda_graph_buffers: + # Use pre-allocated buffers for CUDA graph + model = self._model_ref() + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.ks = buffer['ks'][:total_q_len] + self.ke = buffer['ke'][:total_q_len] + self.lengths = buffer['lengths'][:total_q_len] + self.page_table_size_1 = buffer['page_table_size_1'][:num_seq_len, :max_seq_len] + self.req_all_mem_index = buffer['req_all_mem_index'][:total_seq_len] + + # Zero out page_table_size_1 before filling + self.page_table_size_1.zero_() + + # Compute and copy values into the pre-allocated buffer views + ks_offset = 0 + ke_offset = 0 + lengths_offset = 0 + req_offset = 0 + seq_offset = 0 + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + + # Copy req_all_mem_index + self.req_all_mem_index[req_offset:req_offset + seq_len] = mem_index + + # Fill page_table_size_1 + self.page_table_size_1[i, :seq_len] = mem_index + + # Fill ks, ke, lengths + self.ks[ks_offset:ks_offset + q_seq_len].fill_(seq_offset) + self.ke[ke_offset:ke_offset + q_seq_len] = torch.arange( + seq_offset + 1, seq_offset 
+ q_seq_len + 1, dtype=torch.int, device='cuda' + ) + self.lengths[lengths_offset:lengths_offset + q_seq_len] = torch.arange( + seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda' + ) + + ks_offset += q_seq_len + ke_offset += q_seq_len + lengths_offset += q_seq_len + req_offset += seq_len + seq_offset += seq_len + else: + # Original dynamic allocation for non-CUDA graph mode + self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device='cuda') - for i in range(num_seq_len): - seq_len = self.b_seq_len[i] - q_seq_len = self.b_q_seq_len[i] - mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - req_all_mem_index_list.append(mem_index) - self.page_table_size_1[i, :seq_len] = mem_index - ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset - ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 - ks_list.append(ks) - ke_list.append(ke) - lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) - offset += seq_len + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + req_all_mem_index_list.append(mem_index) + self.page_table_size_1[i, :seq_len] = mem_index + ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks_list.append(ks) + ke_list.append(ke) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + offset += seq_len - self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) - self.ks = torch.cat(ks_list, dim=0) - self.ke = torch.cat(ke_list, dim=0) - self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file + self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) + self.ks = torch.cat(ks_list, dim=0) + self.ke = torch.cat(ke_list, dim=0) + self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file
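

Below is a minimal, self-contained sketch of the buffer-reuse pattern that PATCH 12/12 applies to the NSA indexer state: allocate worst-case tensors once, hand out fixed-address views sized for the current batch, and refresh their contents in place so every CUDA graph replay sees the same storage. This is an illustration only, not code from the patch; the names used here (GraphSafeBuffers, bind, b_att_seq_len, index_topk) are hypothetical and are not part of the LightLLM API.

import torch


class GraphSafeBuffers:
    """Worst-case allocations whose addresses never change after construction."""

    def __init__(self, max_batch_size: int, max_seq_len: int, device: str = "cuda"):
        self.cache_seqlens = torch.empty(max_batch_size, dtype=torch.int32, device=device)
        self.cu_seqlens_k = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
        self.page_table = torch.empty(max_batch_size, max_seq_len, dtype=torch.int32, device=device)

    def bind(self, b_att_seq_len: torch.Tensor, index_topk: int):
        # Views into the pre-allocated storage, sized for the current batch.
        bs = b_att_seq_len.shape[0]
        cache_seqlens = self.cache_seqlens[:bs]
        cu_seqlens_k = self.cu_seqlens_k[: bs + 1]
        # Refresh contents in place (copy_ / cumsum(out=...)) instead of
        # allocating new tensors, mirroring the in-place updates done in
        # init_some_extra_state in the patch.
        cache_seqlens.copy_(b_att_seq_len.clamp(max=index_topk))
        cu_seqlens_k[0] = 0
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
        return cache_seqlens, cu_seqlens_k


if __name__ == "__main__":
    # Requires a CUDA device; values below are arbitrary example inputs.
    bufs = GraphSafeBuffers(max_batch_size=8, max_seq_len=4096)
    seqlens = torch.tensor([3000, 1200, 512], dtype=torch.int32, device="cuda")
    cache_seqlens, cu_seqlens_k = bufs.bind(seqlens, index_topk=2048)
    print(cache_seqlens.tolist())  # [2048, 1200, 512]
    print(cu_seqlens_k.tolist())   # [0, 2048, 3248, 3760]

The design point the sketch tries to show: CUDA graph capture records raw device pointers, so any tensor consumed inside the captured region must keep a stable address across replays; per-step work is therefore expressed as copies and out= writes into pre-captured storage rather than fresh allocations, which is why the patch switches from torch.cat / torch.nn.functional.pad construction to views plus copy_ and cumsum(out=...).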