diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py index 87a7b361e..d64b17bf9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py @@ -67,9 +67,6 @@ def __init__( self.global_rank_ = get_global_rank() self.redundancy_expert_num = get_redundancy_expert_num() self.redundancy_expert_ids = get_redundancy_expert_ids(layer_num) - logger.info( - f"global_rank {self.global_rank_} layerindex {layer_num} redundancy_expertids: {self.redundancy_expert_ids}" - ) self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda") self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda") self.total_expert_num_contain_redundancy = ( diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 48167a067..07b3bf69c 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -36,8 +36,17 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MultiMMWeightTpl): + if isinstance(attr, TransformerLayerWeight): + attr.load_hf_weights(weights) + elif isinstance(attr, MultiMMWeightTpl): with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + + def verify_load(self): + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, TransformerLayerWeight): + attr.verify_load() + super().verify_load() \ No newline at end of file diff --git a/lightllm/common/deepseek2_fp8kv_mem_manager.py b/lightllm/common/deepseek2_fp8kv_mem_manager.py index 00699f4b1..ffa0f2274 100644 --- a/lightllm/common/deepseek2_fp8kv_mem_manager.py +++ b/lightllm/common/deepseek2_fp8kv_mem_manager.py @@ -3,6 +3,6 @@ class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): # scale被追加到kv_buffer末尾, 因此加2, dtype统一改成uint8 - super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 4f106bdcf..46fd87dcf 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -14,8 +14,8 @@ class Deepseek2MemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * 
torch._utils._element_size(self.dtype) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 57ae9838b..167ce2760 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -19,7 +19,7 @@ class MemoryManager: - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -41,15 +41,16 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.can_use_mem_size = self.size - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name + if not is_sub_mem_manager: + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index a919f7b28..c62a2572f 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -62,7 +62,7 @@ def autotune( as needed before invocation. 
""" - def decorator(fn): + def decorator(fn: Callable) -> Callable: return Autotuner( fn=fn, kernel_name=kernel_name, diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 96329eabe..9c9bdb9df 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -20,6 +20,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, InternVLPhi3TpPartModel, diff --git a/lightllm/models/deepseek3_2/__init__.py b/lightllm/models/deepseek3_2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py new file mode 100644 index 000000000..db6e61a1c --- /dev/null +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -0,0 +1,193 @@ +import torch +import weakref +from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager + +class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): + _shared_nsa_buffers = None + + def __init__(self): + super().__init__() + self.lengths = None + self.page_table_size_1 = None + self.ks = None + self.ke = None + self.nsa_cu_seqlens_k = None + self.index_topk = 2048 + return + + @classmethod + def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): + """Get or create pre-allocated buffers for CUDA graph execution""" + if cls._shared_nsa_buffers is None: + # Pre-allocate buffers for max possible sizes + max_total_q_tokens = graph_max_batch_size * max_seq_len + max_total_tokens = graph_max_batch_size * max_seq_len + + cls._shared_nsa_buffers = [ + { + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + }, + { # Second buffer for microbatch overlap if needed + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + } + ] + return cls._shared_nsa_buffers + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + super().init_some_extra_state(model, input_ids) + + # Store weak reference to model for accessing graph parameters + self._model_ref = weakref.ref(model) + + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) + 
self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager + + # Ensure b_ready_cache_len is set for both prefill and decode modes + if self.is_prefill: + # b_ready_cache_len is already set in basemodel.py for prefill + pass + else: + # In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len + # since b_q_seq_len represents the new tokens being processed + if self.b_ready_cache_len is None: + self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len + + # Check if we can use CUDA graph based on batch size and max_len constraints + use_cuda_graph_buffers = False + if (hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and + self.batch_size <= model.graph_max_batch_size and + self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + # Setup nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph + if use_cuda_graph_buffers: + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.nsa_cache_seqlens = buffer['nsa_cache_seqlens'][:self.batch_size] + self.nsa_cu_seqlens_k = buffer['nsa_cu_seqlens_k'][:self.batch_size + 1] + else: + # Create new tensors dynamically + self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device='cuda') + self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device='cuda') + + # Calculate actual values + self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) + assert self.nsa_cache_seqlens.dtype == torch.int32 + + # Compute cumulative sum with padding + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) + self.nsa_cu_seqlens_k[0] = 0 + + # Pre-compute NSA indexer indexing structures + self._init_nsa_indexing_structures() + + def _init_nsa_indexing_structures(self): + """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" + req_all_mem_index_list = [] + ks_list = [] + ke_list = [] + lengths_list = [] + offset = 0 + num_seq_len = self.b_req_idx.shape[0] + max_seq_len = self.b_seq_len.max().item() + + # Calculate total sizes needed + total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) + total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) + + # Check if we should use CUDA graph buffers + use_cuda_graph_buffers = False + if hasattr(self, '_model_ref'): + model = self._model_ref() + if (model is not None and + hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and + self.batch_size <= model.graph_max_batch_size and + self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + if use_cuda_graph_buffers: + # Use pre-allocated buffers for CUDA graph + model = self._model_ref() + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.ks = buffer['ks'][:total_q_len] + self.ke = buffer['ke'][:total_q_len] + self.lengths = buffer['lengths'][:total_q_len] + self.page_table_size_1 = buffer['page_table_size_1'][:num_seq_len, :max_seq_len] + self.req_all_mem_index = buffer['req_all_mem_index'][:total_seq_len] + + # Zero out page_table_size_1 before filling + self.page_table_size_1.zero_() + + # Compute and copy values into the pre-allocated buffer views + ks_offset = 0 + ke_offset = 0 
+ lengths_offset = 0 + req_offset = 0 + seq_offset = 0 + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + + # Copy req_all_mem_index + self.req_all_mem_index[req_offset:req_offset + seq_len] = mem_index + + # Fill page_table_size_1 + self.page_table_size_1[i, :seq_len] = mem_index + + # Fill ks, ke, lengths + self.ks[ks_offset:ks_offset + q_seq_len].fill_(seq_offset) + self.ke[ke_offset:ke_offset + q_seq_len] = torch.arange( + seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device='cuda' + ) + self.lengths[lengths_offset:lengths_offset + q_seq_len] = torch.arange( + seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda' + ) + + ks_offset += q_seq_len + ke_offset += q_seq_len + lengths_offset += q_seq_len + req_offset += seq_len + seq_offset += seq_len + else: + # Original dynamic allocation for non-CUDA graph mode + self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device='cuda') + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + req_all_mem_index_list.append(mem_index) + self.page_table_size_1[i, :seq_len] = mem_index + ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks_list.append(ks) + ke_list.append(ke) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + offset += seq_len + + self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) + self.ks = torch.cat(ks_list, dim=0) + self.ke = torch.cat(ke_list, dim=0) + self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py new file mode 100644 index 000000000..df045dd2d --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -0,0 +1,135 @@ +from sgl_kernel import fast_topk_transform_fused +import deep_gemm +import torch +import torch.nn.functional as F + +from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager +from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks +from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +class NSAIndexerInfer(BaseLayerInfer): + def __init__(self, layer_idx, network_config, mode=[]): + super().__init__() + self.layer_idx_ = layer_idx + self.network_config_ = network_config + self.mode = mode + self.index_topk = network_config["index_topk"] + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = 1 + 
self.tp_v_head_num_ = 1 + self.qk_nope_head_dim = network_config["qk_nope_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_head_dim = network_config["index_head_dim"] + self.eps = network_config["rms_norm_eps"] + self.block_size = network_config["quantization_config"]["weight_block_size"][0] + self.scale_fmt = network_config["quantization_config"]["scale_fmt"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.index_n_heads = network_config["index_n_heads"] + self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale + + return + + def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + + def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + + q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) + q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + + destindex_copy_indexer_ks( + k_fp8, + k_scale, + infer_state.mem_index, + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + ) + + weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale + weights = weights.unsqueeze(-1) * q_scale + + ks = infer_state.ks + ke = infer_state.ke + lengths = infer_state.lengths + page_table_1 = infer_state.page_table_size_1 + + # Use efficient Triton kernel to extract FP8 keys and scales from buffer + k_fp8_, k_scale_ = extract_indexer_ks( + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], + infer_state.req_all_mem_index + ) + + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + + return fast_topk_transform_fused( + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.cu_seqlens_q, + topk=self.index_topk, + ) + + + @staticmethod + def _rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from sgl_kernel import hadamard_transform + + hidden_size = x.size(-1) + assert ( + hidden_size & (hidden_size - 1) + ) == 0, "Hidden size must be a power of 2 for Hadamard transform." 
+ return hadamard_transform(x, scale=hidden_size**-0.5) + + def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): + q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) + k = layer_weight.wk_proj_.mm(hidden_states) + + # TODO + k = F.layer_norm( + k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps + ).type_as(k) + + rotary_emb_fwd( + q[:, :, : self.qk_rope_head_dim], + k[:, None, : self.qk_rope_head_dim], + infer_state.position_cos, + infer_state.position_sin, + ) + + q = self._rotate_activation(q) + k = self._rotate_activation(k) + return q, k diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..df5220427 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -0,0 +1,119 @@ +from functools import partial +from typing import override + +import torch +from sgl_kernel.flash_mla import flash_mla_sparse_fwd +from sgl_kernel.flash_attn import flash_attn_with_kvcache + +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer +from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd + + +class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, network_config, mode) + + self.indexer = NSAIndexerInfer( + layer_idx=self.layer_num_, + network_config=self.network_config_, + mode=mode + ) + self.topk_indices = None + return + + @override + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + + self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) + + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + rmsnorm_forward( + cache_kv[:, :, : self.kv_lora_rank], + weight=layer_weight.kv_a_layernorm_.weight, + eps=self.eps_, + out=cache_kv[:, :, : self.kv_lora_rank], + ) + + rotary_emb_fwd( + q_rope, + cache_kv[:, :, self.kv_lora_rank :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + @override + def _bind_attention(self): + if "triton_fp8kv" in self.mode: + 
self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) + else: + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) + self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) + pass + + def _nsa_context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek3_2FlashAttentionStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + q_all = torch.cat([q_nope, q_rope], dim=-1) + mla_out, _, _ = flash_mla_sparse_fwd( + q=q_all, + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], + indices=self.topk_indices.unsqueeze(1), + sm_scale=self.softmax_scale, + d_v=self.kv_lora_rank, + ) + return mla_out + + def _nsa_token_attention_kernel( + self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=self.topk_indices, + cache_seqlens=infer_state.nsa_cache_seqlens, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, + max_seqlen_q=infer_state.max_q_seq_len, + softmax_scale=self.softmax_scale, + causal=True, + ) + return o_tensor \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py new file mode 100644 index 000000000..47e0bfdac --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -0,0 +1,49 @@ +from typing_extensions import override + +import torch + +from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, NormWeight + + +class NSAIndexerWeight(TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + @override + def _init_weight(self): + prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" + + self.wq_b_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wq_b.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wq_b", + tp_rank=0, + tp_world_size=1, + ) + self.wk_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wk.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wk", + tp_rank=0, + tp_world_size=1, + ) + self.k_norm_ = NormWeight( + f"{prefix}.k_norm.weight", + torch.float32, + bias_name=f"{prefix}.k_norm.bias" + ) + self.weights_proj_ = ROWMMWeight( + weight_name=f"{prefix}.weights_proj.weight", + 
data_type=self.data_type_, + quant_cfg=None, + layer_num=self.layer_num_, + name="weights_proj", + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..2a03e1d6a --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -0,0 +1,16 @@ +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight + + +class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + self.indexer_layer_weight = NSAIndexerWeight( + layer_num=layer_num, + data_type=data_type, + network_config=network_config, + mode=mode, + quant_cfg=quant_cfg + ) + return diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py new file mode 100644 index 000000000..a70c76273 --- /dev/null +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -0,0 +1,22 @@ +from typing import List +from typing_extensions import override +import torch + +from lightllm.common.mem_manager import MemoryManager +from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.distributed.pynccl import PyNcclCommunicator + +class Deepseek3_2MemoryManager(Deepseek2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9 ,is_sub_mem_manager=False): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) + self.indexer_ks_mem_manager = Deepseek2MemoryManager(self.size, torch.uint8, 1, 132, layer_num, is_sub_mem_manager=True) + return + + @override + def get_cell_size(self): + return super().get_cell_size() + 132 + +class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py new file mode 100644 index 000000000..8f1ba85cf --- /dev/null +++ b/lightllm/models/deepseek3_2/model.py @@ -0,0 +1,47 @@ +from lightllm.models.registry import ModelRegistry +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager +@ModelRegistry(["deepseek_v32"]) +class Deepseek3_2TpPartModel(Deepseek2TpPartModel): + # weight class + transformer_weight_class = Deepseek3_2TransformerLayerWeight + + # infer class + transformer_layer_infer_class = 
Deepseek3_2TransformerLayerInfer + + # infer state class + infer_state_class = Deepseek3_2FlashAttentionStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + self.index_topk = self.config["index_topk"] + return + + def _init_inferstate_cls(self): + self.infer_state_class = Deepseek3_2FlashAttentionStateInfo + + def _init_mem_manager(self): + manager_class = Deepseek3_2MemoryManager + if "triton_fp8kv" in self.mode: + manager_class = Deepseek3_2FP8KVMemoryManager + + # mtp 模式下需要在mem manger上扩展draft model使用的layer + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "deepseekv3_eagle": + added_mtp_layer_num += 1 + elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": + added_mtp_layer_num += get_env_start_args().mtp_step + + self.mem_manager = manager_class( + self.max_total_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + mem_fraction=self.mem_fraction, + ) + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/__init__.py b/lightllm/models/deepseek3_2/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek3_2/triton_kernel/act_quant.py b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py new file mode 100644 index 000000000..a4ecd0f51 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py @@ -0,0 +1,137 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/python/sglang/srt/layers/attention/nsa/triton_kernel.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +# Triton implementation +@triton.jit +def _act_quant_kernel( + X_ptr, + Y_ptr, + S_ptr, + M, + N, + group_size: tl.constexpr, + round_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Triton kernel for activation quantization. + + Each block processes BLOCK_M rows and group_size columns. 
+ """ + # Get block IDs + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # FP8 constants + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1.0 / fp8_max + + # Calculate row and column offsets + row_start = pid_m * BLOCK_M + col_start = pid_n * group_size + + # Create offset arrays + rows = row_start + tl.arange(0, BLOCK_M) + cols = col_start + tl.arange(0, BLOCK_N) + + # Mask for valid rows and columns + row_mask = rows < M + col_mask = cols < N + mask = row_mask[:, None] & col_mask[None, :] + + # Load input data + x_ptrs = X_ptr + rows[:, None] * N + cols[None, :] + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Compute absolute max along columns (group_size dimension) for each row + x_abs = tl.abs(x) + amax = tl.max(x_abs, axis=1) # Shape: (BLOCK_M,) + + # Clamp amax to avoid division by zero + amax = tl.maximum(amax, 1e-4) + + # Compute scale + if round_scale: + # Fast round scale using bit manipulation approximation + # This is a simplified version - the exact bit manipulation is harder in Triton + # Using log2 + ceil + pow2 as approximation + log_val = tl.log2(amax * fp8_max_inv) + log_ceil = tl.ceil(log_val) + scale = tl.exp2(log_ceil) + else: + scale = amax * fp8_max_inv + + # Quantize: y = clamp(x / scale, fp8_min, fp8_max) + scale_broadcast = scale[:, None] + y = x / scale_broadcast + y = tl.minimum(tl.maximum(y, fp8_min), fp8_max) + + # Store quantized output + y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :] + tl.store(y_ptrs, y, mask=mask) + + # Store scales + s_cols = pid_n + s_ptrs = S_ptr + rows * (N // group_size) + s_cols + s_mask = row_mask + tl.store(s_ptrs, scale, mask=s_mask) + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization with Triton. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. 
+ """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + + # Flatten all dims except last + N = x.size(-1) + x_flat = x.view(-1, N) + M = x_flat.size(0) + + # Allocate output tensors + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + y_flat = y.view(-1, N) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + s_flat = s.view(-1, N // block_size) + + # Launch kernel + BLOCK_M = 32 + BLOCK_N = block_size + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size)) + round_scale = scale_fmt is not None + + _act_quant_kernel[grid]( + x_flat, + y_flat, + s_flat, + M, + N, + group_size=block_size, + round_scale=round_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=0 if round_scale else 2, + ) + + return y, s diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py new file mode 100644 index 000000000..46095bfb7 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -0,0 +1,275 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_indexer_ks( + K_fp8, + K_scale, + DestLoc, + O_buffer, + stride_k_bs, + stride_k_d, + stride_scale_bs, + stride_scale_d, + stride_o_bs, + stride_o_h, + stride_o_d, + BLOCK_DMODEL: tl.constexpr, +): + """ + Triton kernel to copy FP8 K values and their scales to an indexed output buffer. + + This kernel reads FP8 key values (128 dims) and their float32 scale values, + then writes them to a compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The destination location for each source element is specified by DestLoc. + """ + cur_index = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # Load destination index for this thread + dest_index = tl.load(DestLoc + cur_index).to(tl.int64) + + # Load K_fp8 (128 values) and K_scale (1 value) from source + k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d + k_fp8 = tl.load(k_fp8_ptrs) + + k_scale = tl.load(K_scale + cur_index * stride_scale_bs) + + # Store K_fp8 to O_buffer[:, 0, :128] + # Convert fp8 to uint8 through bitcast for storage in uint8 buffer + o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d + k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True) + tl.store(o_k_ptrs, k_fp8_as_uint8) + + # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32) + # Convert float32 scale to 4 uint8 bytes using bitcast and bit manipulation + o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d + scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True) + + # Store each byte of the float32 scale (little-endian) + for i in range(4): + byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8) + tl.store(o_scale_ptr + i * stride_o_d, byte_val) + + return + + +@torch.no_grad() +def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor): + """ + Copy FP8-quantized key values and their scales to indexed locations in a buffer. + + This function is used in the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) + mechanism to store compressed key representations in a memory buffer. 
Each key + is stored with its FP8 representation (128 bytes) followed by its float32 scale + (4 bytes), for a total of 132 bytes per key. + + Args: + K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn + FP8-quantized key values + K_scale: [q_seq_len, 1] torch.float32 + Quantization scales for each key + DestLoc: [q_seq_len] torch.int32 + Destination indices in the output buffer + O_buffer: [large_size, 1, 132] torch.uint8 + Output buffer where keys and scales will be written. + Must be a uint8 tensor to allow mixed-type storage. + Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales + + Returns: + None (modifies O_buffer in-place) + + Example: + >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() + >>> k_scale = torch.randn(50, 1).cuda() + >>> dest_loc = torch.randint(0, 1024, (50,), dtype=torch.int32).cuda() + >>> o_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer) + >>> # Now o_buffer[dest_loc] contains the packed k_fp8 and k_scale data + """ + seq_len = DestLoc.shape[0] + head_dim = K_fp8.shape[1] + + assert head_dim == 128, f"Expected head_dim=128, got {head_dim}" + assert K_scale.shape[0] == seq_len + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" + + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_indexer_ks[grid]( + K_fp8, + K_scale, + DestLoc, + O_buffer, + K_fp8.stride(0), + K_fp8.stride(1), + K_scale.stride(0), + K_scale.stride(1), + O_buffer.stride(0), + O_buffer.stride(1), + O_buffer.stride(2), + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +def test_destindex_copy_indexer_ks(): + """Test the destindex_copy_indexer_ks kernel""" + import torch.nn.functional as F + + print("=" * 80) + print("Testing destindex_copy_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random destination indices + dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(dest_loc) + + # Create input tensors + k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8 = (k_bf16 / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + # Create output buffer (as uint8 to allow reinterpretation) + o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + + # Run kernel + destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) + + # Extract results + k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) + + # Extract scale by reinterpreting 4 bytes as float32 + scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() + k_scale_out = scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) + + # Verify results at destination locations + k_fp8_extracted = k_fp8_out[dest_loc] + k_scale_extracted = k_scale_out[dest_loc] + + # Check FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8.to(torch.float32), + atol=0, rtol=0 + ) + + # Check scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_out = k_fp8_extracted.to(dtype) * 
k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" + assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test edge cases + print("Testing edge cases...") + + # Test with sequential indices + dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) + + k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) + scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() + k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_seq = torch.allclose( + k_fp8_out_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_out_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Edge case tests passed!") + print() + + # Test with single element + print("Testing single element...") + dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) + + k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) + scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() + k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_single = torch.allclose( + k_fp8_out_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_out_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! 
✓") + print("=" * 80) + + +if __name__ == "__main__": + test_destindex_copy_indexer_ks() \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py new file mode 100644 index 000000000..eb22fbb8f --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -0,0 +1,309 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_extract_indexer_ks( + I_buffer, # Input buffer [large_size, 1, 132] uint8 + SrcLoc, # Source indices [req_size] int32/int64 + O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn + O_scale, # Output scale [req_size] float32 + stride_i_bs, + stride_i_h, + stride_i_d, + stride_o_fp8_bs, + stride_o_fp8_d, + stride_o_scale_bs, + BLOCK_DMODEL: tl.constexpr, +): + """ + Triton kernel to extract FP8 K values and their scales from an indexed buffer. + + This kernel is the inverse of destindex_copy_indexer_ks. It reads from a + compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The source location for each output element is specified by SrcLoc. + """ + cur_index = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # Load source index for this thread + src_index = tl.load(SrcLoc + cur_index).to(tl.int64) + + # Load K_fp8 from I_buffer[:, 0, :128] + i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d + k_fp8_as_uint8 = tl.load(i_k_ptrs) + + # Convert uint8 to fp8 through bitcast + k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) + + # Store K_fp8 to output + o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d + tl.store(o_k_ptrs, k_fp8) + + # Load K_scale from I_buffer[:, 0, 128:132] (4 bytes for float32) + # Load 4 bytes and reconstruct float32 (little-endian) + i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d + + # Load 4 bytes individually and combine them into uint32 + byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) + byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) + byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) + byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) + + # Combine bytes into uint32 (little-endian: byte0 is LSB) + scale_as_uint32 = byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24) + + # Bitcast uint32 to float32 + k_scale = scale_as_uint32.to(tl.float32, bitcast=True) + + # Store scale to output + o_scale_ptr = O_scale + cur_index * stride_o_scale_bs + tl.store(o_scale_ptr, k_scale) + + return + + +@torch.no_grad() +def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Extract FP8-quantized key values and their scales from indexed locations in a buffer. + + This function is the inverse operation of destindex_copy_indexer_ks. It's used in + the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) mechanism to retrieve + compressed key representations from a memory buffer. + + Args: + I_buffer: [large_size, 1, 132] torch.uint8 + Input buffer containing packed FP8 keys and float32 scales. 
+ Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales + SrcLoc: [req_size] torch.int32 or torch.int64 + Source indices to extract from the input buffer + + Returns: + tuple containing: + - K_fp8: [req_size, 128] torch.float8_e4m3fn + FP8-quantized key values + - K_scale: [req_size] torch.float32 + Quantization scales for each key + + Example: + >>> i_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> src_loc = torch.tensor([10, 20, 30], dtype=torch.int32).cuda() + >>> k_fp8, k_scale = extract_indexer_ks(i_buffer, src_loc) + >>> # k_fp8.shape == [3, 128], k_scale.shape == [3] + """ + req_size = SrcLoc.shape[0] + head_dim = 128 + + assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" + assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" + + # Allocate output tensors + O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) + O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) + + grid = (req_size,) + num_warps = 1 + + _fwd_kernel_extract_indexer_ks[grid]( + I_buffer, + SrcLoc, + O_fp8, + O_scale, + I_buffer.stride(0), + I_buffer.stride(1), + I_buffer.stride(2), + O_fp8.stride(0), + O_fp8.stride(1), + O_scale.stride(0), + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + + return O_fp8, O_scale + + +def test_extract_indexer_ks(): + """Test the extract_indexer_ks kernel against the copy kernel""" + import torch.nn.functional as F + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks + + print("=" * 80) + print("Testing extract_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random indices for writing + write_indices = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(write_indices) + + # Create input tensors + k_bf16_original = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16_original.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_original = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_original = (k_bf16_original / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + # Create buffer and write data using destindex_copy_indexer_ks + buffer = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_original, k_scale_original, write_indices, buffer) + + # Now extract the data back using extract_indexer_ks + k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, write_indices) + + # Verify FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8_original.to(torch.float32), + atol=0, rtol=0 + ) + + # Verify scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale_original.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16_original, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: 
{cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" + assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test with sequential indices + print("Testing sequential indices...") + write_indices_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, write_indices_seq, buffer_seq) + k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, write_indices_seq) + + fp8_match_seq = torch.allclose( + k_fp8_ext_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_ext_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Sequential test passed!") + print() + + # Test with single element + print("Testing single element...") + write_idx_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, write_idx_single, buffer_single) + k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, write_idx_single) + + fp8_match_single = torch.allclose( + k_fp8_ext_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_ext_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + # Test with larger batch to check performance characteristics + print("Testing larger batch (performance check)...") + write_indices_large = torch.randint(0, large_size * 10, (500,), device="cuda", dtype=torch.int32).unique() + actual_large_len = len(write_indices_large) + k_bf16_large = torch.randn((actual_large_len, head_dim), dtype=dtype, device="cuda") + k_abs_max_large = k_bf16_large.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_large = (k_abs_max_large / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_large = (k_bf16_large / k_abs_max_large).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_large = torch.zeros((large_size * 10, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_large, k_scale_large, write_indices_large, buffer_large) + + # Warm up + for _ in range(3): + _ = extract_indexer_ks(buffer_large, write_indices_large) + + # Time 
it + torch.cuda.synchronize() + import time + start = time.time() + for _ in range(100): + k_fp8_ext_large, k_scale_ext_large = extract_indexer_ks(buffer_large, write_indices_large) + torch.cuda.synchronize() + elapsed = time.time() - start + + fp8_match_large = torch.allclose( + k_fp8_ext_large.to(torch.float32), + k_fp8_large.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_large = torch.allclose( + k_scale_ext_large, + k_scale_large.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Large batch (size={actual_large_len}): FP8={fp8_match_large}, Scale={scale_match_large}") + print(f" Average time per call: {elapsed/100*1000:.3f} ms") + assert fp8_match_large and scale_match_large + print("✓ Large batch test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! ✓") + print("=" * 80) + + +if __name__ == "__main__": + test_extract_indexer_ks() diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py new file mode 100644 index 000000000..2fc92662a --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -0,0 +1,139 @@ +import triton +import triton.language as tl +import torch + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + Q_ptr, KV_ptr, KVScale_ptr, Weights_ptr, MemIndex_ptr, + CuSeqlenKs_ptr, CuSeqlenKe_ptr, Output_ptr, + seq_len, seq_len_kv, num_heads, head_dim, + stride_q_seq, stride_q_head, stride_q_dim, + stride_kv_pool, stride_kv_dim, + stride_w_seq, stride_w_head, + stride_o_seq, stride_o_kv, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Compute the range of seq positions this block handles + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + # Offset arrays for this block + offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) + + # Initialize accumulator for logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Create masks + mask_m = offs_m < seq_len + mask_n = offs_n < seq_len_kv + + # Load mem_indices for the KV positions + mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) + + # Load scales for K + scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) + + # Loop over all heads + for h in range(num_heads): + # Load weights for this head + weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, + mask=mask_m, other=0.0) + + # Initialize score accumulator for this head + score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Loop over head_dim in blocks + for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): + d_start = d_block * BLOCK_SIZE_D + offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) + mask_d = offs_d < head_dim + + # Load Q for this head and dimension block + # Q shape: (seq_len, num_heads, head_dim) + q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim + mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] + q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) + + # Load K for this dimension block + # KV shape: (pool_size, head_dim) as FP8 data + k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim + mask_k = mask_n[:, None] & mask_d[None, :] + k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) + + # Apply scale to K (scale is per-row of K) + k = k * scales[:, None] + + # 
Compute partial dot product: q @ k.T + # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) + # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) + score += tl.dot(q, tl.trans(k)) + + # Apply ReLU to score + score = tl.maximum(score, 0.0) + + # Multiply by weights and accumulate to logits + logits += score * weights[:, None] + + # Apply mask based on cu_seqlen_ks and cu_seqlen_ke + mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) + mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) + + mask_lo = offs_n[None, :] >= mask_ks[:, None] + mask_hi = offs_n[None, :] < mask_ke[:, None] + mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] + + # Apply mask (-inf for masked positions) + logits = tl.where(mask_valid, logits, float('-inf')) + + # Store output + out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv + mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) + tl.store(out_ptrs, logits, mask=mask_out) + + +def fp8_paged_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor = None +) -> torch.Tensor: + seq_len, num_heads, head_dim = q.shape + seq_len_kv = mem_index.shape[0] + + if out is None: + output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) + else: + output = out + + BLOCK_SIZE_M = 16 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_D = 128 + + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) + + _fp8_paged_mqa_logits_kernel[grid]( + q, kv, kv_scale, weights, mem_index, + cu_seqlen_ks, cu_seqlen_ke, output, + seq_len, seq_len_kv, num_heads, head_dim, + q.stride(0), q.stride(1), q.stride(2), + kv.stride(0), kv.stride(1), + weights.stride(0), weights.stride(1), + output.stride(0), output.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + + return output \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py new file mode 100644 index 000000000..dbf5c5199 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -0,0 +1,103 @@ +import triton +import triton.language as tl +import torch +from typing import Tuple + +fp8_min = -448.0 +fp8_max = 448.0 +fp8_dtype = torch.float8_e4m3fn + +@triton.jit +def _per_token_group_quant_mla_deep_gemm_masked_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + masked_m_ptr, + group_size, + y_stride_b, + y_stride_t, + y_q_stride_b, + y_q_stride_t, + y_s_stride_b, + y_s_stride_g, + eps, + fp8_min, + fp8_max, + NUM_GROUP: tl.constexpr, + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor for deep_gemm grouped_gemm_masked. + This function converts the tensor values into float8 values. 
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_token_group_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = fp8_dtype,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+    """
+    Quantizes input values to float8 with per-token-group quantization
+    for deep_gemm grouped_gemm_masked, specialized for the MLA absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
+        -fp8_max,
+        fp8_max,
+        num_tiles_k,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
\ No newline at end of file
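For reference, a quick shape walk-through of `per_token_group_quant_mla_deep_gemm_masked_fp8`, assuming b=4 groups, m=37 tokens, k=512 and the default group_size=128. This is only a usage sketch (it needs a CUDA device with Triton and the new module importable), not part of the patch.

```python
import torch
from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import (
    per_token_group_quant_mla_deep_gemm_masked_fp8,
)

x = torch.randn(4, 37, 512, device="cuda", dtype=torch.bfloat16)
x_q, x_s, masked_m, m, aligned_m = per_token_group_quant_mla_deep_gemm_masked_fp8(x)

print(x_q.shape)     # torch.Size([4, 256, 512]); float8_e4m3fn, rows >= 37 are left uninitialized
print(x_s.shape)     # torch.Size([4, 256, 4]);  float32 scales, transposed to (b, aligned_m, k // group_size)
print(masked_m)      # tensor([37, 37, 37, 37], dtype=torch.int32), consumed by the masked grouped GEMM
print(m, aligned_m)  # 37 256
```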
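The `is_sub_mem_manager` flag threaded through `MemoryManager` / `Deepseek2MemoryManager` exists so that the nested 132-byte indexer K/scale pool created inside `Deepseek3_2MemoryManager` does not register a second `SharedInt` free-token counter for the same rank. Below is a minimal, self-contained sketch of that composition pattern; the classes are simplified stand-ins, not the real lightllm managers, and the cell sizes are illustrative.

```python
class PoolSketch:
    """Stand-in for MemoryManager: only top-level pools publish a shared counter."""
    def __init__(self, size: int, cell_bytes: int, is_sub_mem_manager: bool = False):
        self.size = size
        self.cell_bytes = cell_bytes
        self.can_use_mem_size = size
        if not is_sub_mem_manager:
            # the real code creates a SharedInt keyed by server name + rank here
            self.shared_can_use_token_num = size


class Deepseek32PoolSketch(PoolSketch):
    """Stand-in for Deepseek3_2MemoryManager: owns a nested indexer K/scale pool."""
    def __init__(self, size: int):
        super().__init__(size, cell_bytes=576 * 2)  # main MLA KV cell size, illustrative
        # 128 FP8 key bytes + 4 scale bytes per token, skipping shared-memory bookkeeping
        self.indexer_ks_pool = PoolSketch(size, cell_bytes=132, is_sub_mem_manager=True)


pool = Deepseek32PoolSketch(65536)
assert not hasattr(pool.indexer_ks_pool, "shared_can_use_token_num")
```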
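In `init_some_extra_state`, the attention length of each request is clamped to the indexer top-k (`index_topk = 2048`) and a leading zero is prepended to form the cumulative key offsets passed to `flash_attn_with_kvcache`. A small CPU-only sketch of that computation with made-up lengths:

```python
import torch

index_topk = 2048
b_att_seq_len = torch.tensor([1500, 4096, 300], dtype=torch.int32)

nsa_cache_seqlens = b_att_seq_len.clamp(max=index_topk)   # tensor([1500, 2048,  300])
nsa_cu_seqlens_k = torch.zeros(b_att_seq_len.numel() + 1, dtype=torch.int32)
torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32, out=nsa_cu_seqlens_k[1:])
print(nsa_cu_seqlens_k)                                   # tensor([   0, 1500, 3548, 3848], dtype=torch.int32)
```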
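`destindex_copy_indexer_ks` and `extract_indexer_ks` move data in a 132-byte-per-token layout: 128 FP8 key bytes followed by the little-endian float32 scale. The pure-PyTorch sketch below mirrors that packing with dtype reinterpretation only, as a readable reference for what the two Triton kernels do; the function names here are illustrative, not taken from the patch.

```python
import torch

def pack_indexer_ks(k_fp8, k_scale, dest_loc, buffer):
    """buffer: [pool, 1, 132] uint8; k_fp8: [n, 128] float8_e4m3fn; k_scale: [n, 1] float32."""
    buffer[dest_loc, 0, :128] = k_fp8.view(torch.uint8)        # reinterpret the FP8 bytes
    buffer[dest_loc, 0, 128:132] = k_scale.view(torch.uint8)   # 4 raw bytes per float32 scale

def unpack_indexer_ks(buffer, src_loc):
    rows = buffer[src_loc, 0]                                  # [n, 132] uint8 gather
    k_fp8 = rows[:, :128].view(torch.float8_e4m3fn)
    k_scale = rows[:, 128:132].contiguous().view(torch.float32).squeeze(-1)
    return k_fp8, k_scale

buf = torch.zeros(1024, 1, 132, dtype=torch.uint8, device="cuda")
k = torch.randn(8, 128, device="cuda").to(torch.float8_e4m3fn)
s = torch.rand(8, 1, device="cuda")
loc = torch.arange(8, device="cuda")
pack_indexer_ks(k, s, loc, buf)
k2, s2 = unpack_indexer_ks(buf, loc)
assert torch.equal(k2.view(torch.uint8), k.view(torch.uint8)) and torch.equal(s2, s.squeeze(-1))
```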
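The `act_quant` wrapper quantizes activations in groups of `block_size` along the last dimension with an absmax scale (`scale = amax / 448` when `scale_fmt` is None; the `round_scale` path instead rounds the scale up to a power of two). Here is a dense PyTorch reference for the default path, useful as a checking sketch against the Triton kernel on small inputs; it is not part of the patch.

```python
import torch

FP8_MAX = 448.0  # float8_e4m3fn max magnitude, matching the kernel's constants

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    assert x.shape[-1] % block_size == 0
    g = x.float().reshape(*x.shape[:-1], -1, block_size)        # (..., n_groups, block_size)
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)   # same floor as the kernel
    scale = amax / FP8_MAX                                      # one scale per group
    y = (g / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return y.reshape(x.shape), scale.squeeze(-1)

# x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
# y_ref, s_ref = act_quant_ref(x)   # shapes match act_quant(x, 128, None)
```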
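`fp8_paged_mqa_logits` computes, for each query token m and each pooled key n gathered through `mem_index`, the head-summed weighted score `sum_h relu(q[m,h] . k[n]) * w[m,h]`, then masks every row to its `[ks[m], ke[m])` window with -inf. A dense PyTorch equivalent that dequantizes the keys first can serve as a correctness oracle for small shapes; argument names follow the Triton wrapper, everything else is a sketch.

```python
import torch

def paged_mqa_logits_ref(q, kv, kv_scale, weights, mem_index, ks, ke):
    # q: [s_q, h, d], kv: [pool, d] float8, kv_scale: [pool] float32
    # weights: [s_q, h], mem_index: [s_kv], ks/ke: [s_q]
    k = kv[mem_index].float() * kv_scale[mem_index, None]        # dequantized gathered keys [s_kv, d]
    score = torch.einsum("mhd,nd->hmn", q.float(), k).relu()     # [h, s_q, s_kv]
    logits = (score * weights.float().t()[:, :, None]).sum(0)    # weighted head sum -> [s_q, s_kv]
    n = torch.arange(k.shape[0], device=q.device)
    mask = (n[None, :] >= ks[:, None]) & (n[None, :] < ke[:, None])
    return logits.masked_fill(~mask, float("-inf"))
```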