2 changes: 1 addition & 1 deletion lightllm/common/basemodel/basemodel.py
@@ -110,7 +110,7 @@ def __init__(self, kvargs):
self._init_some_value()
self._init_custom()
self._init_inferstate_cls()
self._autotune_warmup()
# self._autotune_warmup()
Severity: medium

The call to _autotune_warmup() has been commented out. If this was a debugging change, it should be reverted; if autotuning is intentionally disabled for this model, it would be better to control that with a configuration flag, for clarity and to avoid accidental performance regressions.

Suggested change
# self._autotune_warmup()
self._autotune_warmup()
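If warmup is meant to be switchable rather than commented out, a minimal sketch of the flag-based approach the comment suggests could look like the following; the `disable_autotune_warmup` key is hypothetical and not an existing kvargs option.

# Hypothetical kvargs flag, shown only to illustrate the suggestion above.
if not kvargs.get("disable_autotune_warmup", False):
    self._autotune_warmup()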

self._init_padded_req()
# The wait must happen before the cudagraph init, to avoid incorrect capture
self._wait_other_modules_ready()
@@ -36,8 +36,17 @@ def load_hf_weights(self, weights):
"""
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, MultiMMWeightTpl):
if isinstance(attr, TransformerLayerWeight):
attr.load_hf_weights(weights)
elif isinstance(attr, MultiMMWeightTpl):
with self.lock:
attr.load_hf_weights(weights)
elif isinstance(attr, BaseWeight):
attr.load_hf_weights(weights)

def verify_load(self):
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, TransformerLayerWeight):
attr.verify_load()
super().verify_load()
4 changes: 2 additions & 2 deletions lightllm/common/deepseek2_fp8kv_mem_manager.py
@@ -3,6 +3,6 @@


class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False):
# The scale is appended to the end of kv_buffer, hence head_dim + 2; dtype is unified to uint8
super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction)
super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager)
4 changes: 2 additions & 2 deletions lightllm/common/deepseek2_mem_manager.py
@@ -14,8 +14,8 @@


class Deepseek2MemoryManager(MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False):
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager)

def get_cell_size(self):
return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
17 changes: 9 additions & 8 deletions lightllm/common/mem_manager.py
@@ -19,7 +19,7 @@


class MemoryManager:
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False):
self.size = size
self.head_num = head_num
self.head_dim = head_dim
@@ -41,15 +41,16 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False

self.can_use_mem_size = self.size

# Shared via shared memory; the router module reads it for accurate scheduling estimates. The nccl port serves as the marker of a single instance on one machine, to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name
if not is_sub_mem_manager:
# Shared via shared memory; the router module reads it for accurate scheduling estimates. The nccl port serves as the marker of a single instance on one machine, to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name

rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)
rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self._init_buffers(
self.size,
dtype,
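The new `is_sub_mem_manager` flag lets a nested buffer manager skip the shared-memory token counter that the router reads for scheduling estimates, so only the top-level manager is reported. A rough sketch of how a composite manager might use it follows; the class name, buffer shape, and constructor arguments here are assumptions, not the PR's actual implementation.

import torch

from lightllm.common.mem_manager import MemoryManager
from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager


class Deepseek3_2MemoryManagerSketch(Deepseek2MemoryManager):
    # Illustrative only; the real Deepseek3_2MemoryManager may be wired differently.
    def __init__(self, size, dtype, head_num, head_dim, layer_num, **kwargs):
        super().__init__(size, dtype, head_num, head_dim, layer_num, **kwargs)
        # The nested manager opts out of the router-visible shared counter, so the
        # scheduler only sees the top-level KV cache accounting.
        # head_dim=132 assumes 128 fp8 bytes per indexer key plus a 4-byte scale.
        self.indexer_ks_mem_manager = MemoryManager(
            size, torch.uint8, 1, 132, layer_num, is_sub_mem_manager=True
        )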
2 changes: 1 addition & 1 deletion lightllm/common/triton_utils/autotuner.py
@@ -62,7 +62,7 @@ def autotune(
as needed before invocation.
"""

def decorator(fn):
def decorator(fn: Callable) -> Callable:
return Autotuner(
fn=fn,
kernel_name=kernel_name,
1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -20,6 +20,7 @@
from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel
from lightllm.models.phi3.model import Phi3TpPartModel
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel
from lightllm.models.internvl.model import (
InternVLLlamaTpPartModel,
InternVLPhi3TpPartModel,
2 changes: 1 addition & 1 deletion lightllm/models/deepseek2/model.py
@@ -50,7 +50,7 @@ def __init__(self, model):
self.softmax_scale = self.softmax_scale * mscale * mscale


@ModelRegistry(["deepseek_v2", "deepseek_v3"])
@ModelRegistry(["deepseek_v2", "deepseek_v3", "deepseek_v32"])
Severity: high

The model name "deepseek_v32" is being registered to the Deepseek2TpPartModel class. This seems incorrect as it's for the new Deepseek3.2 model. The new Deepseek3_2TpPartModel class in lightllm/models/deepseek3_2/model.py should be registered for this model name instead. Please revert this change.

Suggested change
@ModelRegistry(["deepseek_v2", "deepseek_v3", "deepseek_v32"])
@ModelRegistry(["deepseek_v2", "deepseek_v3"])

class Deepseek2TpPartModel(LlamaTpPartModel):
# weight class
transformer_weight_class = Deepseek2TransformerLayerWeight
Empty file.
69 changes: 69 additions & 0 deletions lightllm/models/deepseek3_2/infer_struct.py
@@ -0,0 +1,69 @@
import torch
from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager

class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo):

def __init__(self):
super().__init__()
self.lengths = None
self.page_table_size_1 = None
self.ks = None
self.ke = None
self.nsa_cu_seqlens_k = None
self.index_topk = 2048
return

def init_some_extra_state(self, model, input_ids: torch.Tensor):
super().init_some_extra_state(model, input_ids)
assert isinstance(self.mem_manager, Deepseek3_2MemoryManager)
self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager

# Ensure b_ready_cache_len is set for both prefill and decode modes
if self.is_prefill:
# b_ready_cache_len is already set in basemodel.py for prefill
pass
else:
# In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len
# since b_q_seq_len represents the new tokens being processed
if self.b_ready_cache_len is None:
self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len

self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk)
assert self.nsa_cache_seqlens.dtype == torch.int32
self.nsa_cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)

# Pre-compute NSA indexer indexing structures
self._init_nsa_indexing_structures()

def _init_nsa_indexing_structures(self):
"""Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer"""
mem_index_list = []
ks_list = []
ke_list = []
lengths_list = []
offset = 0
num_seq_len = self.b_req_idx.shape[0]
self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda')

for i in range(num_seq_len):
seq_len = self.b_seq_len[i]
q_seq_len = self.b_q_seq_len[i]
mem_index = self.req_manager.req_to_token_indexs[i, :seq_len]
mem_index_list.append(mem_index)
self.page_table_size_1[i, :seq_len] = mem_index
ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset
ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1
ks_list.append(ks)
ke_list.append(ke)
lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda'))
offset += seq_len

self.mem_index = torch.cat(mem_index_list, dim=0)
# ks : [seq_len_q], marks the start position of each query token's kv range
# ke : [seq_len_q], marks the end position of each query token's kv range
self.ks = torch.cat(ks_list, dim=0)
self.ke = torch.cat(ke_list, dim=0)
self.lengths = torch.cat(lengths_list, dim=0)
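A worked example of the flattened layout these tensors describe, following the loop above (values are illustrative):

# Two prefill sequences with b_seq_len = [3, 2] and b_q_seq_len = [3, 2] flatten
# their kv tokens back to back, so offset is 0 for the first sequence and 3 for
# the second:
#   ks      = [0, 0, 0, 3, 3]   # start of each query token's visible kv range
#   ke      = [1, 2, 3, 4, 5]   # exclusive end of that range (causal, +1 per step)
#   lengths = [1, 2, 3, 1, 2]   # number of kv tokens each query token can attend to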
136 changes: 136 additions & 0 deletions lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py
@@ -0,0 +1,136 @@
from sgl_kernel import fast_topk_transform_fused
import deep_gemm
import torch
import torch.nn.functional as F

from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer
from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight
from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager
from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks
from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks
from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

class NSAIndexerInfer(BaseLayerInfer):
def __init__(self, layer_idx, network_config, mode=[]):
super().__init__()
self.layer_idx_ = layer_idx
self.network_config_ = network_config
self.mode = mode
self.index_topk = network_config["index_topk"]
self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_
self.tp_k_head_num_ = 1
self.tp_v_head_num_ = 1
self.qk_nope_head_dim = network_config["qk_nope_head_dim"]
self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
self.index_head_dim = network_config["index_head_dim"]
self.eps = network_config["rms_norm_eps"]
self.block_size = network_config["quantization_config"]["weight_block_size"][0]
self.scale_fmt = network_config["quantization_config"]["scale_fmt"]
self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
self.index_n_heads = network_config["index_n_heads"]
self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale

return

def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False):
seq_len_kv = kv.shape[0]

if cost_only:
start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv)
end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv)
count_ones_per_row = (end - start).clamp(min=0)
return count_ones_per_row.sum()

k = kv
q = q.float()
k = k.float()

mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi

score = torch.einsum('mhd,nd->hmn', q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float('-inf'))

cost = mask.sum()
return logits, cost

def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor:

q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight)
q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt)
k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt)

destindex_copy_indexer_ks(
k_fp8.unsqueeze(1),
k_scale.unsqueeze(1),
infer_state.mem_index,
infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_]
)

weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
weights = weights.unsqueeze(-1) * q_scale

mem_index = infer_state.mem_index
ks = infer_state.ks
ke = infer_state.ke
lengths = infer_state.lengths
page_table_1 = infer_state.page_table_size_1

k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index)

logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke)

# 返回 : [seq_q_len, topk] 无效的位置使用-1填充
Severity: medium

This comment is in Chinese. For better maintainability and consistency with the rest of the codebase, please translate it to English.

Suggested change
# 返回 : [seq_q_len, topk] 无效的位置使用-1填充
# Returns: [seq_q_len, topk], invalid positions are filled with -1

return fast_topk_transform_fused(
score=logits, # [seq_len_q, seq_len_kv]
lengths=lengths, # [seq_len_q]
page_table_size_1=page_table_1, # [seq_len_q, max(lengths)], invalid entries are filled with 0
cu_seqlens_q=infer_state.cu_seqlens_q, # [seq_len_q + 1]
topk=self.index_topk,
)


def get_k_float32_from_buffer(self, buffer: torch.Tensor):
k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn)
k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1]
k_float32 = k_fp8.float() * k_scale
return k_float32

@staticmethod
def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
from sgl_kernel import hadamard_transform

hidden_size = x.size(-1)
assert (
hidden_size & (hidden_size - 1)
) == 0, "Hidden size must be a power of 2 for Hadamard transform."
return hadamard_transform(x, scale=hidden_size**-0.5)

def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor,
infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight):
q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim)
k = layer_weight.wk_proj_.mm(hidden_states)

k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps)

rotary_emb_fwd(
q[:, :, : self.qk_rope_head_dim],
k[:, None, : self.qk_rope_head_dim],
infer_state.position_cos,
infer_state.position_sin,
)

q = self._rotate_activation(q)
k = self._rotate_activation(k)
return q, k