Commit 1f574f2

draft

1 parent 299ff47 commit 1f574f2

File tree: 15 files changed, +175 -222 lines changed

lightllm/common/basemodel/infer_struct.py

Lines changed: 3 additions & 0 deletions

@@ -88,6 +88,9 @@ def __init__(self):
         self.dp_output_split_sizes: List[List[int]] = None
         self.dp_input_split_sizes: List[List[int]] = None

+        # Dedicated to managing the buffers of hybrid-attention models
+        self.buffer_indexes: torch.Tensor = None
+
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             (

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 2 deletions

@@ -7,7 +7,6 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager
-from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from typing import Tuple
 from typing_extensions import override
@@ -250,7 +249,7 @@ def _linear_attn(
     ):
         assert layer_weight.is_linear, "layer_weight must be linear"
         assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager)
-        assert isinstance(infer_state.req_manager, Qwen3NextReqManager)
+
         input = input.view(-1, infer_cls.embed_dim_)
         buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx]
         conv_states, ssm_states = infer_state.mem_manager.get_state_cache_buffer(self.layer_idx_)
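
For orientation (an illustrative sketch, not part of this commit's diff): how a linear-attention layer addresses its per-request recurrent state through buffer_idx. Shapes and tensor names below are assumptions, not lightllm's actual layout.

import torch

# One conv/ssm state slot per buffer index, for a single linear-attention layer.
buffer_num, conv_dim, ssm_dim = 8, 16, 32
conv_states = torch.zeros(buffer_num, conv_dim)
ssm_states = torch.zeros(buffer_num, ssm_dim)

# Each request in the batch owns one buffer slot; gathering by buffer_idx
# pulls out the recurrent state for exactly the requests in this batch.
buffer_idx = torch.tensor([3, 5], dtype=torch.long)
batch_conv_state = conv_states[buffer_idx]  # shape (2, conv_dim)
batch_ssm_state = ssm_states[buffer_idx]    # shape (2, ssm_dim)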

lightllm/models/qwen3next/mem_manager.py

Lines changed: 7 additions & 23 deletions

@@ -9,6 +9,7 @@
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridMemManager

 logger = init_logger(__name__)

@@ -34,24 +35,7 @@ def get_cell_size(self):
         return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype)


-class HaveStateBuffer(Protocol):
-    def alloc_state_cache_buffer(self, need_size):
-        ...
-
-    def free_state_cache_buffer(self, free_buffer_indexes):
-        ...
-
-    def get_state_cache_buffer(self, layer_index):
-        ...
-
-    def get_state_cache_can_use_size(self):
-        ...
-
-    def copy_state_cache_buffer(self, src_idx, tgt_idx):
-        pass
-
-
-class Qwen3NextMemoryManager(MemoryManager, HaveStateBuffer):
+class Qwen3NextMemoryManager(HybridMemManager):
     def __init__(
         self,
         full_attn_cache_size,
@@ -121,32 +105,32 @@ def free_all(self):
         return

     @override
-    def get_state_cache_buffer(self, layer_index) -> Tuple[torch.Tensor, torch.Tensor]:
+    def get_buffer(self, layer_index) -> Tuple[torch.Tensor, torch.Tensor]:
         assert layer_index < self.layer_num, "layer_index is out of range"
         assert (layer_index + 1) % self.full_attention_interval != 0, "layer_index is not linear attention layer"
         real_layer_index = layer_index - layer_index // self.full_attention_interval
         return self.conv_state_mem_manager.buffer[real_layer_index], self.ssm_state_mem_manager.buffer[real_layer_index]

     @override
-    def free_state_cache_buffer(self, free_buffer_indexes: List[int], reset=True):
+    def free_buffer(self, free_buffer_indexes: List[int], reset=True):
         # conv_state and ssm_state share the same buffer_idx
         self.conv_state_mem_manager.free(free_buffer_indexes)
         if reset:
             self.conv_state_mem_manager.buffer[:, free_buffer_indexes] = 0
             self.ssm_state_mem_manager.buffer[:, free_buffer_indexes] = 0

     @override
-    def alloc_state_cache_buffer(self, need_size):
+    def alloc_buffer(self, need_size):
         # conv_state and ssm_state share the same buffer_idx
         buffer_indexes = self.conv_state_mem_manager.alloc(need_size)
         return buffer_indexes

     @override
-    def get_state_cache_can_use_size(self):
+    def get_buffer_can_use_size(self):
         return self.conv_state_mem_manager.can_use_mem_size

     @override
-    def copy_state_cache_buffer(self, src_idx, tgt_idx):
+    def copy_buffer(self, src_idx, tgt_idx):
         assert src_idx is not None and tgt_idx is not None
         assert src_idx != tgt_idx
         # Use slice operation and in-place copy for better performance
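
As a reading aid (an illustrative sketch, not part of this commit): the renamed API keeps one shared index space for conv_state and ssm_state, so alloc_buffer / free_buffer / copy_buffer always act on both tensors through a single buffer_idx. The toy pool below assumes flat tensors and a simple free list, and ignores the real manager's layer mapping and sub-managers.

import torch


class ToyHybridBufferPool:
    def __init__(self, layer_num: int, buffer_num: int, conv_dim: int, ssm_dim: int):
        # conv_states and ssm_states are indexed by the same buffer_idx.
        self.conv_states = torch.zeros(layer_num, buffer_num, conv_dim)
        self.ssm_states = torch.zeros(layer_num, buffer_num, ssm_dim)
        self.free_indexes = list(range(buffer_num))

    def alloc_buffer(self, need_size: int) -> torch.Tensor:
        assert need_size <= len(self.free_indexes), "not enough free buffers"
        taken, self.free_indexes = self.free_indexes[:need_size], self.free_indexes[need_size:]
        return torch.tensor(taken, dtype=torch.long)

    def free_buffer(self, free_buffer_indexes, reset: bool = True):
        self.free_indexes.extend(int(i) for i in free_buffer_indexes)
        if reset:
            self.conv_states[:, free_buffer_indexes] = 0
            self.ssm_states[:, free_buffer_indexes] = 0

    def get_buffer(self, layer_index: int):
        return self.conv_states[layer_index], self.ssm_states[layer_index]

    def get_buffer_can_use_size(self) -> int:
        return len(self.free_indexes)

    def copy_buffer(self, src_idx: int, tgt_idx: int):
        # Both state tensors move together so a request's state stays consistent.
        self.conv_states[:, tgt_idx] = self.conv_states[:, src_idx]
        self.ssm_states[:, tgt_idx] = self.ssm_states[:, src_idx]


pool = ToyHybridBufferPool(layer_num=2, buffer_num=4, conv_dim=8, ssm_dim=16)
idx = pool.alloc_buffer(2)                  # two requests draw from one index space
pool.copy_buffer(int(idx[0]), int(idx[1]))
pool.free_buffer(idx)                       # freeing resets both state tensors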

lightllm/models/qwen3next/model.py

Lines changed: 14 additions & 11 deletions

@@ -9,8 +9,8 @@
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager
-from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager
 from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput

 logger = init_logger(__name__)

@@ -25,6 +25,7 @@ class Qwen3NextTpPartModel(Qwen3MOEModel):
     post_layer_infer_class = Qwen3NextPostLayerInfer

     def __init__(self, kvargs) -> None:
+        self.mem_manager: Qwen3NextMemoryManager = None
         super().__init__(kvargs)

     @override
@@ -85,13 +86,15 @@ def _init_mem_manager(self):
             mem_fraction=self.mem_fraction,
         )

-    @override
-    def _init_req_manager(self):
-        create_max_seq_len = 0
-
-        if self.batch_max_tokens is not None:
-            create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
-        if self.max_seq_length is not None:
-            create_max_seq_len = max(create_max_seq_len, self.max_seq_length)
-
-        self.req_manager = Qwen3NextReqManager(self.max_req_num, create_max_seq_len, self.mem_manager)
+    def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
+        from lightllm.common.basemodel.infer_lock import g_infer_state_lock
+        from lightllm.common.basemodel.infer_context import g_infer_context
+
+        infer_state = super()._create_inferstate(model_input, microbatch_index)
+        g_infer_state_lock.acquire()
+        if g_infer_context.radix_cache is not None:
+            g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(infer_state.batch_size)
+        buffer_indexes = self.mem_manager.alloc_buffer(infer_state.batch_size)
+        g_infer_state_lock.release()
+        infer_state.buffer_indexes = buffer_indexes
+        return infer_state
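
The _create_inferstate override above follows a common pattern: call the parent factory, then attach model-specific per-batch state while the shared allocator is held under a lock, with cache eviction as the fallback when the pool runs short. A hedged, self-contained sketch of that pattern with invented names (not lightllm's classes):

import threading


class BaseEngine:
    def create_state(self, batch_size: int) -> dict:
        return {"batch_size": batch_size}


class HybridEngine(BaseEngine):
    def __init__(self, buffer_num: int = 16):
        self.lock = threading.Lock()
        self.free_slots = list(range(buffer_num))

    def create_state(self, batch_size: int) -> dict:
        state = super().create_state(batch_size)
        with self.lock:  # the slot pool is shared, so allocate under the lock
            # A real implementation would evict cached entries here instead of asserting.
            assert batch_size <= len(self.free_slots), "no free slots left"
            state["buffer_indexes"] = [self.free_slots.pop() for _ in range(batch_size)]
        return state


engine = HybridEngine()
print(engine.create_state(3))  # {'batch_size': 3, 'buffer_indexes': [15, 14, 13]}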

lightllm/models/qwen3next/req_manager.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

lightllm/server/core/objs/start_args_type.py

Lines changed: 0 additions & 2 deletions

@@ -63,8 +63,6 @@ class StartArgs:
     token_healing_mode: bool = field(default=False)
     output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]})
     first_token_constraint_mode: bool = field(default=False)
-    enable_multimodal: bool = field(default=False)
-    enable_multimodal_audio: bool = field(default=False)
     enable_tpsp_mix_mode: bool = field(default=False)
     enable_dp_prefill_balance: bool = field(default=False)
     enable_decode_microbatch_overlap: bool = field(default=False)

lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
+from typing import Set, Protocol, List
+
+import torch
+from sortedcontainers import SortedSet
+
+from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
+from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
+from lightllm.server.router.model_infer.infer_batch import InferReq
+
+
+class HybridMemManager(MemoryManager):
+    def alloc_buffer(self, need_size):
+        ...
+
+    def free_buffer(self, free_buffer_indexes):
+        ...
+
+    def get_buffer(self, layer_index):
+        ...
+
+    def get_buffer_can_use_size(self):
+        ...
+
+    def copy_buffer(self, src_idx, tgt_idx):
+        ...
+
+
+class HybridRadixCache(RadixCache):
+    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
+        self.mem_manager: HybridMemManager = mem_manager
+        super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
+        self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: x.time_id)
+        self.evict_buffer_set.add(self.root_node)
+
+    def free_radix_cache_to_get_enough_buffer(self, need_buffer_num):
+        if need_buffer_num > self.mem_manager.get_buffer_can_use_size():
+            need_evict_buffer_num = need_buffer_num - self.mem_manager.get_buffer_can_use_size()
+
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            release_buffers = []
+
+            def release_buffer(buffer_idx):
+                release_buffers.append(buffer_idx)
+                return
+
+            self.evict_buffer(need_evict_buffer_num, release_buffer, release_mem)
+            self.mem_manager.free_buffer(release_buffers)
+            if len(release_mems) > 0:
+                mem_index = torch.concat(release_mems)
+                self.mem_manager.free(mem_index)
+        return
+
+    def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback):
+        while need_evict_buffer_num > 0:
+            node = self.evict_buffer_set.pop()
+            if node.buffer_idx is not None:
+                evict_buffer_callback(node.buffer_idx)
+                need_evict_buffer_num -= 1
+            else:
+                # In the hybrid-attention scenario, only nodes whose buffer_idx is not None can be matched.
+                # If buffer_idx is None, the match is treated as a failure,
+                # so the node can simply be released.
+                if node.is_leaf() and node.ref_counter == 0:
+                    self._remove_leaf_node(node)
+        return
+
+    def insert_for_hybrid_radix_cache(self, reqs: List["InferReq"]):
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+        from lightllm.common.basemodel.infer_lock import g_infer_state_lock
+
+        # Make sure there is enough room for the new buffers
+        g_infer_state_lock.acquire()
+        self.free_radix_cache_to_get_enough_buffer(len(reqs))
+        new_buffer_indexes = self.mem_manager.alloc_buffer(len(reqs))
+        g_infer_state_lock.release()
+
+        for i, req in enumerate(reqs):
+            input_token_ids = req.get_input_token_ids()
+            key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
+            value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu()
+            buffer_idx = req.buffer_idx
+
+            # Allocate a new buffer and copy the contents of the current buffer
+            self.mem_manager.copy_buffer(buffer_idx, new_buffer_indexes[i])
+            req.buffer_idx = new_buffer_indexes[i]
+
+            _, new_shared_kv_node = self.insert(key, value)
+            new_shared_kv_node.buffer_idx = buffer_idx
+            self.dec_node_ref_counter(req.shared_kv_node)
+            self.add_node_ref_counter(new_shared_kv_node)
+            req.shared_kv_node = new_shared_kv_node
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        ans_value_list = []
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+
+        while tree_node != self.root_node and tree_node.buffer_idx is None:
+            self.dec_node_ref_counter(tree_node)
+            if tree_node.is_leaf() and tree_node.ref_counter == 0:
+                tree_node = self._remove_leaf_node(tree_node)
+            else:
+                tree_node = tree_node.parent
+            ans_value_list.pop()
+
+        if tree_node == self.root_node:
+            return None, 0, None
+
+        value = torch.concat(ans_value_list)
+        return tree_node, len(value), value
+
+    def _remove_leaf_node(self, node: TreeNode):
+        self.evict_tree_set.discard(node)
+        self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
+        parent_node: TreeNode = node.parent
+        parent_node.remove_child(node)
+        if parent_node.is_leaf():
+            self.evict_tree_set.add(parent_node)
+        return parent_node
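
To make the match_prefix override above concrete (a toy illustration with invented names, not lightllm's API): a prefix hit is only usable when the matched node still owns a state buffer, so the search backs off toward the root, dropping the deepest value segments, until it reaches a node whose buffer_idx is set.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    tokens: List[int]
    buffer_idx: Optional[int] = None
    parent: Optional["ToyNode"] = None


def usable_prefix(deepest: ToyNode, segments: List[List[int]], root: ToyNode):
    node = deepest
    while node is not root and node.buffer_idx is None:
        node = node.parent
        segments.pop()  # drop the value segment of the unusable node
    if node is root:
        return None, []
    return node, [t for seg in segments for t in seg]


root = ToyNode(tokens=[])
a = ToyNode(tokens=[1, 2], buffer_idx=7, parent=root)
b = ToyNode(tokens=[3, 4], buffer_idx=None, parent=a)  # its state buffer was evicted

node, prefix = usable_prefix(b, [[1, 2], [3, 4]], root)
print(node.buffer_idx, prefix)  # 7 [1, 2] -- only the buffer-backed prefix is reused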

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 6 additions & 0 deletions

@@ -31,6 +31,12 @@ def __init__(self):
         self.node_value_len = 0
         self.node_prefix_total_len = 0

+        # Dedicated to hybrid-attention models (e.g. Qwen3Next): such models need to
+        # track a unique buffer_idx per request, and putting it here lets them reuse
+        # the existing radix_cache code.
+        # For pure-attention models this buffer_idx always stays None.
+        self.buffer_idx = None
+
     def get_compare_key(self):
         return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id)

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 7 additions & 17 deletions

@@ -113,10 +113,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
         # .cpu() is a stream-blocking operation
         value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()

-        buffer_idx = None
-        if hasattr(self.req_manager, "req_to_buffer_indexes"):
-            buffer_idx = self.req_manager.req_to_buffer_indexes[req.req_idx].cpu()
-        prefix_len, _ = self.radix_cache.insert(key, value, buffer_idx=buffer_idx)
+        prefix_len, node = self.radix_cache.insert(key, value)
+        node.buffer_idx = req.buffer_idx
         old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
         free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
         if req.shared_kv_node is not None:
@@ -345,6 +343,10 @@ def __init__(
         self.nixl_pd_task_failed_num: int = 0
         self.nixl_trans_device_id: int = -1

+        # Used when a request keeps a single fixed-size buffer for its whole lifetime,
+        # e.g. the hybrid-attention model Qwen3Next
+        self.buffer_idx = -1
+
         # When enable_cpu_cache is enabled, a finished request's kv cache is offloaded
         # to the cpu cache; this flag marks the status of that offload task
         self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED
@@ -397,26 +399,14 @@ def _match_radix_cache(self):
         key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
         key = key[0 : len(key) - 1]  # the last token is not needed: one extra token is kept so prefill can emit the next token's value
         share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
-
-        if share_node is not None:
-            if g_infer_context.use_hybrid_radix_cache:
-                if share_node.buffer_idx is None:
-                    g_infer_context.radix_cache.dec_node_ref_counter(share_node)
-                    share_node = None
-
         if share_node is not None:
             self.shared_kv_node = share_node
             ready_cache_len = share_node.node_prefix_total_len
             # the cpu-to-gpu copy is a stream-blocking operation
             g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
             self.cur_kv_len = int(ready_cache_len)  # serialization issue: the value may be numpy.int64, convert with int(*)
             self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
-
-            if g_infer_context.use_hybrid_radix_cache:
-                cur_buffer_idx = g_infer_context.req_manager.req_to_buffer_indexes[self.req_idx]
-                g_infer_context.req_manager.mem_manager.copy_state_cache_buffer(
-                    share_node.buffer_idx, cur_buffer_idx
-                )
+            self.buffer_idx = share_node.buffer_idx

         self.shm_req.shm_cur_kv_len = self.cur_kv_len
         return

lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,7 +1,6 @@
 from .chunked_prefill.impl import ChunkedPrefillBackend
 from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
 from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
-from .chunked_prefill.impl_for_hybrid_radix_cache import HybridRadixCacheBackend
 from .chunked_prefill.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend
 from .chunked_prefill.impl_for_reward_model import RewardModelBackend
 from .chunked_prefill.impl_for_token_healing import TokenHealingBackend
