
Commit 00fb9d7

tmp
1 parent 1f574f2 commit 00fb9d7

9 files changed: +71, -19 lines


lightllm/common/req_manager.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
 from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
 from lightllm.utils.config_utils import get_vocab_size
+from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridMemManager
 
 logger = init_logger(__name__)

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 2 deletions
@@ -251,8 +251,8 @@ def _linear_attn(
         assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager)
 
         input = input.view(-1, infer_cls.embed_dim_)
-        buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx]
-        conv_states, ssm_states = infer_state.mem_manager.get_state_cache_buffer(self.layer_idx_)
+        buffer_idx = infer_state.buffer_indexes
+        conv_states, ssm_states = infer_state.mem_manager.get_buffer(self.layer_idx_)
 
         mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
         q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba)

lightllm/models/qwen3next/mem_manager.py

Lines changed: 1 addition & 2 deletions
@@ -50,6 +50,7 @@ def __init__(
         conv_state_shape: Tuple[int, ...],
         ssm_state_dtype: torch.dtype,
         ssm_state_shape: Tuple[int, ...],
+        max_req_num: int,
         always_copy=False,
         mem_fraction=0.9,
     ):
@@ -80,8 +81,6 @@ def __init__(
             f"Ssm state use : "
             f"{self.ssm_state_mem_manager.get_cell_size() * linear_attn_cache_size / 1024 ** 3} GB Memory.\n"
         )
-        self.EMPTY_BUFFER_INDEX = -1
-        self.HOLD_BUFFER_INDEX = self.conv_state_mem_manager.HOLD_TOKEN_MEMINDEX
         super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction)
 
     @override

lightllm/models/qwen3next/model.py

Lines changed: 28 additions & 3 deletions
@@ -11,6 +11,7 @@
 from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager
 from lightllm.server.core.objs.start_args_type import StartArgs
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager
 
 logger = init_logger(__name__)
 
@@ -83,18 +84,42 @@ def _init_mem_manager(self):
                 self.head_linear_k_dim,
                 self.head_linear_v_dim,
             ),
+            max_req_num=self.max_req_num,
             mem_fraction=self.mem_fraction,
         )
 
+    @override
     def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
         from lightllm.common.basemodel.infer_lock import g_infer_state_lock
-        from lightllm.common.basemodel.infer_context import g_infer_context
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
         infer_state = super()._create_inferstate(model_input, microbatch_index)
+
+        buffer_indexes = self.req_manager.req_to_buffer_indexes[model_input.b_req_idx]
+        empty_indexes = buffer_indexes == self.req_manager.EMPTY_BUFFER_INDEX
+        num_empty = empty_indexes.sum()
+        if num_empty == 0:
+            return infer_state
+
         g_infer_state_lock.acquire()
         if g_infer_context.radix_cache is not None:
-            g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(infer_state.batch_size)
-        buffer_indexes = self.mem_manager.alloc_buffer(infer_state.batch_size)
+            g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(num_empty)
+        new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda()
         g_infer_state_lock.release()
+
+        buffer_indexes[empty_indexes] = new_buffer_indexes
+        self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] = buffer_indexes
         infer_state.buffer_indexes = buffer_indexes
         return infer_state
+
+    @override
+    def _init_req_manager(self):
+        create_max_seq_len = 0
+
+        if self.batch_max_tokens is not None:
+            create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
+        if self.max_seq_length is not None:
+            create_max_seq_len = max(create_max_seq_len, self.max_seq_length)
+
+        self.req_manager = Qwen3NextReqManager(self.max_req_num, create_max_seq_len, self.mem_manager)
+        return

lightllm/models/qwen3next/req_manager.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+from typing import override, List
+
+import torch
+
+from lightllm.common.req_manager import ReqManager
+from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager
+
+
+class Qwen3NextReqManager(ReqManager):
+    def __init__(self, max_request_num, max_sequence_length, mem_manager: Qwen3NextMemoryManager):
+        super().__init__(max_request_num, max_sequence_length, mem_manager)
+        self.EMPTY_BUFFER_INDEX = -1
+        self.req_to_buffer_indexes = torch.zeros((self.max_request_num + 1), dtype=torch.int32, device="cuda")
+        self.req_to_buffer_indexes[:] = self.EMPTY_BUFFER_INDEX
+
+    @override
+    def free(self, free_req_indexes: List[int], free_token_index):
+        self.free_buffer(free_req_indexes)
+        super().free(free_req_indexes, free_token_index)
+
+    @override
+    def free_all(self):
+        self.req_to_buffer_indexes[:] = self.EMPTY_BUFFER_INDEX
+        super().free_all()
+        return
+
+    def free_buffer(self, free_req_indexes: List[int]):
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        if g_infer_context.radix_cache is None:
+            self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes])
+            self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX
+        return
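
Note: the pattern introduced across these files is that every request slot maps to one linear-attention state buffer index, initialized to EMPTY_BUFFER_INDEX, filled lazily when the request first reaches _create_inferstate, and cleared again when the request is freed or paused. A minimal standalone sketch of that mapping, using CPU tensors and a toy allocator; ToyBufferAllocator and the batch values are illustrative only, not part of lightllm:

import torch

EMPTY_BUFFER_INDEX = -1

class ToyBufferAllocator:
    # stand-in for the memory manager's alloc_buffer / free_buffer
    def __init__(self, num_buffers: int):
        self.free_slots = list(range(num_buffers))

    def alloc_buffer(self, n: int) -> torch.Tensor:
        # pop n free slots and hand them back as buffer indexes
        return torch.tensor([self.free_slots.pop() for _ in range(n)], dtype=torch.int32)

    def free_buffer(self, idxs: torch.Tensor) -> None:
        self.free_slots.extend(int(i) for i in idxs if int(i) != EMPTY_BUFFER_INDEX)

# request -> buffer mapping, as in Qwen3NextReqManager.__init__
max_request_num = 8
req_to_buffer_indexes = torch.full((max_request_num + 1,), EMPTY_BUFFER_INDEX, dtype=torch.int32)

# lazy fill, as in the new _create_inferstate: only the EMPTY slots of this batch get fresh buffers
allocator = ToyBufferAllocator(num_buffers=4)
b_req_idx = torch.tensor([0, 2, 3])
buffer_indexes = req_to_buffer_indexes[b_req_idx]
empty_mask = buffer_indexes == EMPTY_BUFFER_INDEX
if int(empty_mask.sum()) > 0:
    buffer_indexes[empty_mask] = allocator.alloc_buffer(int(empty_mask.sum()))
    req_to_buffer_indexes[b_req_idx] = buffer_indexes

# release, as in free / free_buffer: return the slots and mark them EMPTY again
allocator.free_buffer(req_to_buffer_indexes[b_req_idx])
req_to_buffer_indexes[b_req_idx] = EMPTY_BUFFER_INDEX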

lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@
 
 from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
 from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
-from lightllm.server.router.model_infer.infer_batch import InferReq
 
 
 class HybridMemManager(MemoryManager):
@@ -69,7 +68,7 @@ def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token
                 self._remove_leaf_node(node)
         return
 
-    def insert_for_hybrid_radix_cache(self, reqs: List["InferReq"]):
+    def insert_for_hybrid_radix_cache(self, reqs):
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
         from lightllm.common.basemodel.infer_lock import g_infer_state_lock
 
lightllm/server/router/manager.py

Lines changed: 0 additions & 9 deletions
@@ -251,22 +251,13 @@ async def loop_for_fwd(
                 estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
                 paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i)
 
-                # Get hit rate from radix cache if available
-                hit_rate = 0.0
-                if self.radix_cache_client is not None:
-                    try:
-                        hit_rate = self.radix_cache_client.get_match_prefix_hit_rate(d_i)
-                    except Exception as e:
-                        logger.warning(f"Failed to get hit rate from radix cache: {e}")
-
                 logger.debug(
                     f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n"
                     f"dp_i {d_i} paused req num: {paused_req_num} \n"
                     f"dp_i {d_i} frozen token num: {frozen_token_num} \n"
                     f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n"
                     f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
                     f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token\n"
-                    f"dp_i {d_i} match_prefix hit_rate: {hit_rate:.4f}"
                 )
                 self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num)
                 # pd decode mode need to update token_load more frequently

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 5 additions & 0 deletions
@@ -183,8 +183,10 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
         if pause_reqs:
             g_infer_state_lock.acquire()
 
+            pause_req_ids = []
             free_token_index = []
             for req in pause_reqs:
+                pause_req_ids.append(req.req_id)
                 if self.args.diverse_mode:
                     # when pausing, the master/slave relationship used by diverse mode must be cleared
                     req.clear_master_slave_state()
@@ -201,6 +203,9 @@
             free_token_index = custom_cat(free_token_index)
             self.req_manager.free_token(free_token_index)
 
+            if hasattr(self.req_manager, "free_buffer"):
+                self.req_manager.free_buffer(pause_req_ids)
+
             g_infer_state_lock.release()
         return self
 
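Note: the pause path guards the extra cleanup with hasattr, so request managers without per-request state buffers (non-hybrid models) are left untouched. A small sketch of that duck-typed dispatch; the two stand-in classes are illustrative only:

class PlainReqManager:
    # no linear-attention state, nothing extra to release on pause
    pass

class BufferedReqManager:
    # exposes free_buffer, like Qwen3NextReqManager
    def free_buffer(self, req_ids):
        print(f"released state buffers for reqs {req_ids}")

def on_pause(req_manager, pause_req_ids):
    # only managers that actually own per-request buffers get the call
    if hasattr(req_manager, "free_buffer"):
        req_manager.free_buffer(pause_req_ids)

on_pause(PlainReqManager(), [1, 2])     # no-op
on_pause(BufferedReqManager(), [1, 2])  # prints the release message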

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 0 additions & 1 deletion
@@ -192,7 +192,6 @@ def init_model(self, kvargs):
             shm_req_manager=self.shm_req_manager,
             vocab_size=self.model.vocab_size,
         )
-
         # initialize the communication tensor used by dp mode; unused when not in dp mode
         if self.dp_size > 1:
             self.dp_reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
