Skip to content

Commit ee5c4df

Browse files
committed
add radix cache hit rate
1 parent 11b012f commit ee5c4df

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
125125
)
126126
self.tree_total_tokens_num.arr[0] = 0
127127

128+
# Hit rate tracking
129+
self.match_prefix_total_calls = SharedArray(
130+
f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64
131+
)
132+
self.match_prefix_total_calls.arr[0] = 0
133+
self.match_prefix_hit_tokens = SharedArray(
134+
f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64
135+
)
136+
self.match_prefix_hit_tokens.arr[0] = 0
137+
128138
def insert(self, key, value=None, buffer_idx=None) -> Tuple[int, Optional[TreeNode]]:
129139
if value is None:
130140
value = key
@@ -232,13 +242,21 @@ def _insert_helper_no_recursion(
232242

233243
def match_prefix(self, key, update_refs=False):
234244
assert len(key) != 0
245+
246+
# Track total calls
247+
self.match_prefix_total_calls.arr[0] += 1
248+
235249
ans_value_list = []
236250
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
237251
if tree_node != self.root_node:
238252
if len(ans_value_list) != 0:
239253
value = torch.concat(ans_value_list)
240254
else:
241255
value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
256+
257+
# Track hit tokens
258+
self.match_prefix_hit_tokens.arr[0] += len(value)
259+
242260
return tree_node, len(value), value
243261
else:
244262
self.dec_node_ref_counter(self.root_node)
@@ -509,6 +527,24 @@ def release_mem(mem_index):
509527
self.mem_manager.free(mem_index)
510528
return
511529

530+
def get_match_prefix_hit_rate(self):
531+
"""Get the hit rate as a ratio of hit tokens to total requested tokens"""
532+
total_calls = self.match_prefix_total_calls.arr[0]
533+
if total_calls == 0:
534+
return 0.0
535+
# We calculate hit rate as the average hit tokens per call
536+
# Note: This is a simplified metric. For true hit rate, you might want to track total requested tokens
537+
total_hit_tokens = self.match_prefix_hit_tokens.arr[0]
538+
return total_hit_tokens / total_calls if total_calls > 0 else 0.0
539+
540+
def get_match_prefix_stats(self):
541+
"""Get detailed match_prefix statistics"""
542+
return {
543+
"total_calls": self.match_prefix_total_calls.arr[0],
544+
"total_hit_tokens": self.match_prefix_hit_tokens.arr[0],
545+
"hit_rate": self.get_match_prefix_hit_rate(),
546+
}
547+
512548

513549
class _RadixCacheReadOnlyClient:
514550
"""
@@ -520,6 +556,13 @@ def __init__(self, unique_name, total_token_num, rank_in_node):
520556
self.tree_total_tokens_num = SharedArray(
521557
f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64
522558
)
559+
# Hit rate tracking
560+
self.match_prefix_total_calls = SharedArray(
561+
f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64
562+
)
563+
self.match_prefix_hit_tokens = SharedArray(
564+
f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64
565+
)
523566

524567
def get_refed_tokens_num(self):
525568
return self.refed_tokens_num.arr[0]
@@ -530,6 +573,22 @@ def get_tree_total_tokens_num(self):
530573
def get_unrefed_tokens_num(self):
531574
return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0]
532575

576+
def get_match_prefix_hit_rate(self):
577+
"""Get the hit rate as a ratio of hit tokens to total calls"""
578+
total_calls = self.match_prefix_total_calls.arr[0]
579+
if total_calls == 0:
580+
return 0.0
581+
total_hit_tokens = self.match_prefix_hit_tokens.arr[0]
582+
return total_hit_tokens / total_calls if total_calls > 0 else 0.0
583+
584+
def get_match_prefix_stats(self):
585+
"""Get detailed match_prefix statistics"""
586+
return {
587+
"total_calls": self.match_prefix_total_calls.arr[0],
588+
"total_hit_tokens": self.match_prefix_hit_tokens.arr[0],
589+
"hit_rate": self.get_match_prefix_hit_rate(),
590+
}
591+
533592

534593
class RadixCacheReadOnlyClient:
535594
def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size):
@@ -546,3 +605,9 @@ def get_tree_total_tokens_num(self, dp_rank_in_node):
546605

547606
def get_unrefed_tokens_num(self, dp_rank_in_node):
548607
return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num()
608+
609+
def get_match_prefix_hit_rate(self, dp_rank_in_node):
610+
return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_hit_rate()
611+
612+
def get_match_prefix_stats(self, dp_rank_in_node):
613+
return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_stats()

0 commit comments

Comments
 (0)