Skip to content

Commit 299ff47

Browse files
committed
reset
1 parent b158299 commit 299ff47

File tree

1 file changed

+6
-78
lines changed

1 file changed

+6
-78
lines changed

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 6 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,6 @@ def __init__(self):
3131
self.node_value_len = 0
3232
self.node_prefix_total_len = 0
3333

34-
# 用于混合线性注意力模型, 例如Qwen3Next
35-
# 在混合线性注意力情景中,buffer_idx 可以有值也可以为None
36-
# 但是如果为None则不能作为最终改的匹配节点
37-
self.buffer_idx = None
38-
3934
def get_compare_key(self):
4035
return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id)
4136

@@ -130,16 +125,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
130125
)
131126
self.tree_total_tokens_num.arr[0] = 0
132127

133-
# Hit rate tracking
134-
self.match_prefix_total_calls = SharedArray(
135-
f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64
136-
)
137-
self.match_prefix_total_calls.arr[0] = 0
138-
self.match_prefix_hit_tokens = SharedArray(
139-
f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64
140-
)
141-
self.match_prefix_hit_tokens.arr[0] = 0
142-
143128
def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]:
144129
if value is None:
145130
value = key
@@ -247,21 +232,13 @@ def _insert_helper_no_recursion(
247232

248233
def match_prefix(self, key, update_refs=False):
249234
assert len(key) != 0
250-
251-
# Track total calls
252-
self.match_prefix_total_calls.arr[0] += 1
253-
254235
ans_value_list = []
255236
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
256237
if tree_node != self.root_node:
257238
if len(ans_value_list) != 0:
258239
value = torch.concat(ans_value_list)
259240
else:
260241
value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
261-
262-
# Track hit tokens
263-
self.match_prefix_hit_tokens.arr[0] += len(value)
264-
265242
return tree_node, len(value), value
266243
else:
267244
self.dec_node_ref_counter(self.root_node)
@@ -343,29 +320,27 @@ def _match_prefix_helper_no_recursion(
343320
else:
344321
assert False, "error state"
345322

346-
def evict(self, need_remove_tokens, need_remove_buffers, evict_callback):
323+
def evict(self, need_remove_tokens, evict_callback):
347324
if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens:
348325
assert False, f"""can not free tree tokens {need_remove_tokens},
349326
tree_total_tokens_num {self.tree_total_tokens_num.arr[0]},
350327
refed_tokens_num {self.refed_tokens_num.arr[0]}"""
351328
num_evicted = 0
352-
release_buffers = []
353329
while num_evicted < need_remove_tokens:
354330
node: TreeNode = self.evict_tree_set.pop(0)
355331
assert (
356332
node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node
357333
), "error evict tree node state"
358334
num_evicted += len(node.token_mem_index_value)
359335
evict_callback(node.token_mem_index_value)
360-
release_buffers.append(node.buffer_idx)
361336
# update total token num
362337
self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
363338
parent_node: TreeNode = node.parent
364339
parent_node.remove_child(node)
365340
if parent_node.is_leaf():
366341
self.evict_tree_set.add(parent_node)
367342

368-
return release_buffers
343+
return
369344

370345
def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
371346
"""
@@ -519,38 +494,20 @@ def _print_helper(self, node: TreeNode, indent):
519494
self._print_helper(child, indent=indent + 2)
520495
return
521496

522-
def free_radix_cache_to_get_enough_token(self, need_token_num, need_evict_buffer_num=0):
497+
def free_radix_cache_to_get_enough_token(self, need_token_num):
523498
assert self.mem_manager is not None
524-
if need_token_num > self.mem_manager.can_use_mem_size or need_evict_buffer_num > 0:
499+
if need_token_num > self.mem_manager.can_use_mem_size:
525500
need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
526501
release_mems = []
527502

528503
def release_mem(mem_index):
529504
release_mems.append(mem_index)
530505
return
531506

532-
release_buffers = self.evict(need_evict_token_num, need_evict_buffer_num, release_mem)
507+
self.evict(need_evict_token_num, release_mem)
533508
mem_index = torch.concat(release_mems)
534509
self.mem_manager.free(mem_index)
535-
return release_buffers
536-
537-
def get_match_prefix_hit_rate(self):
538-
"""Get the hit rate as a ratio of hit tokens to total requested tokens"""
539-
total_calls = self.match_prefix_total_calls.arr[0]
540-
if total_calls == 0:
541-
return 0.0
542-
# We calculate hit rate as the average hit tokens per call
543-
# Note: This is a simplified metric. For true hit rate, you might want to track total requested tokens
544-
total_hit_tokens = self.match_prefix_hit_tokens.arr[0]
545-
return total_hit_tokens / total_calls if total_calls > 0 else 0.0
546-
547-
def get_match_prefix_stats(self):
548-
"""Get detailed match_prefix statistics"""
549-
return {
550-
"total_calls": self.match_prefix_total_calls.arr[0],
551-
"total_hit_tokens": self.match_prefix_hit_tokens.arr[0],
552-
"hit_rate": self.get_match_prefix_hit_rate(),
553-
}
510+
return
554511

555512

556513
class _RadixCacheReadOnlyClient:
@@ -563,13 +520,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node):
563520
self.tree_total_tokens_num = SharedArray(
564521
f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64
565522
)
566-
# Hit rate tracking
567-
self.match_prefix_total_calls = SharedArray(
568-
f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64
569-
)
570-
self.match_prefix_hit_tokens = SharedArray(
571-
f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64
572-
)
573523

574524
def get_refed_tokens_num(self):
575525
return self.refed_tokens_num.arr[0]
@@ -580,22 +530,6 @@ def get_tree_total_tokens_num(self):
580530
def get_unrefed_tokens_num(self):
581531
return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0]
582532

583-
def get_match_prefix_hit_rate(self):
584-
"""Get the hit rate as a ratio of hit tokens to total calls"""
585-
total_calls = self.match_prefix_total_calls.arr[0]
586-
if total_calls == 0:
587-
return 0.0
588-
total_hit_tokens = self.match_prefix_hit_tokens.arr[0]
589-
return total_hit_tokens / total_calls if total_calls > 0 else 0.0
590-
591-
def get_match_prefix_stats(self):
592-
"""Get detailed match_prefix statistics"""
593-
return {
594-
"total_calls": self.match_prefix_total_calls.arr[0],
595-
"total_hit_tokens": self.match_prefix_hit_tokens.arr[0],
596-
"hit_rate": self.get_match_prefix_hit_rate(),
597-
}
598-
599533

600534
class RadixCacheReadOnlyClient:
601535
def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size):
@@ -612,9 +546,3 @@ def get_tree_total_tokens_num(self, dp_rank_in_node):
612546

613547
def get_unrefed_tokens_num(self, dp_rank_in_node):
614548
return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num()
615-
616-
def get_match_prefix_hit_rate(self, dp_rank_in_node):
617-
return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_hit_rate()
618-
619-
def get_match_prefix_stats(self, dp_rank_in_node):
620-
return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_stats()

0 commit comments

Comments
 (0)