@@ -125,6 +125,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
125125 )
126126 self .tree_total_tokens_num .arr [0 ] = 0
127127
128+ # Hit rate tracking
129+ self .match_prefix_total_calls = SharedArray (
130+ f"{ unique_name } _match_prefix_total_calls_{ rank_in_node } " , (1 ,), dtype = np .int64
131+ )
132+ self .match_prefix_total_calls .arr [0 ] = 0
133+ self .match_prefix_hit_tokens = SharedArray (
134+ f"{ unique_name } _match_prefix_hit_tokens_{ rank_in_node } " , (1 ,), dtype = np .int64
135+ )
136+ self .match_prefix_hit_tokens .arr [0 ] = 0
137+
128138 def insert (self , key , value = None , buffer_idx = None ) -> Tuple [int , Optional [TreeNode ]]:
129139 if value is None :
130140 value = key
@@ -232,13 +242,21 @@ def _insert_helper_no_recursion(
232242
233243 def match_prefix (self , key , update_refs = False ):
234244 assert len (key ) != 0
245+
246+ # Track total calls
247+ self .match_prefix_total_calls .arr [0 ] += 1
248+
235249 ans_value_list = []
236250 tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
237251 if tree_node != self .root_node :
238252 if len (ans_value_list ) != 0 :
239253 value = torch .concat (ans_value_list )
240254 else :
241255 value = torch .zeros ((0 ,), device = "cpu" , dtype = self ._value_dtype )
256+
257+ # Track hit tokens
258+ self .match_prefix_hit_tokens .arr [0 ] += len (value )
259+
242260 return tree_node , len (value ), value
243261 else :
244262 self .dec_node_ref_counter (self .root_node )
@@ -509,6 +527,24 @@ def release_mem(mem_index):
509527 self .mem_manager .free (mem_index )
510528 return
511529
def get_match_prefix_hit_rate(self):
    """Return the average number of hit tokens per ``match_prefix`` call.

    Despite the name, the result is not a 0..1 ratio: it is the accumulated
    hit-token counter divided by the call counter. A token-level hit rate
    would additionally need the total number of requested tokens, which
    these two counters do not record.

    Returns:
        float: ``match_prefix_hit_tokens / match_prefix_total_calls``, or
        ``0.0`` when the call counter is not positive.
    """
    # Convert the array elements to native ints so the result is a plain
    # Python float rather than a NumPy scalar.
    total_calls = int(self.match_prefix_total_calls.arr[0])
    if total_calls <= 0:
        return 0.0
    total_hit_tokens = int(self.match_prefix_hit_tokens.arr[0])
    return total_hit_tokens / total_calls
539+
def get_match_prefix_stats(self):
    """Return the ``match_prefix`` counters and the derived rate as a dict.

    The counters are converted to native Python numbers so the dict can be
    passed to ``json.dumps`` or similar without a custom encoder; the raw
    array elements are NumPy scalars when the backing array is a NumPy array.

    Returns:
        dict: ``"total_calls"`` (int), ``"total_hit_tokens"`` (int) and
        ``"hit_rate"`` (float, the value of ``get_match_prefix_hit_rate()``).
    """
    return {
        "total_calls": int(self.match_prefix_total_calls.arr[0]),
        "total_hit_tokens": int(self.match_prefix_hit_tokens.arr[0]),
        "hit_rate": float(self.get_match_prefix_hit_rate()),
    }
547+
512548
513549class _RadixCacheReadOnlyClient :
514550 """
@@ -520,6 +556,13 @@ def __init__(self, unique_name, total_token_num, rank_in_node):
520556 self .tree_total_tokens_num = SharedArray (
521557 f"{ unique_name } _tree_total_tokens_num_{ rank_in_node } " , (1 ,), dtype = np .int64
522558 )
559+ # Hit rate tracking
560+ self .match_prefix_total_calls = SharedArray (
561+ f"{ unique_name } _match_prefix_total_calls_{ rank_in_node } " , (1 ,), dtype = np .int64
562+ )
563+ self .match_prefix_hit_tokens = SharedArray (
564+ f"{ unique_name } _match_prefix_hit_tokens_{ rank_in_node } " , (1 ,), dtype = np .int64
565+ )
523566
def get_refed_tokens_num(self):
    """Return element 0 of the ``refed_tokens_num`` shared array."""
    refed_counter = self.refed_tokens_num
    return refed_counter.arr[0]
@@ -530,6 +573,22 @@ def get_tree_total_tokens_num(self):
def get_unrefed_tokens_num(self):
    """Return the tree's total token count minus the referenced token count.

    Both values are read from element 0 of their respective shared arrays.
    """
    tree_total = self.tree_total_tokens_num.arr[0]
    refed = self.refed_tokens_num.arr[0]
    return tree_total - refed
532575
def get_match_prefix_hit_rate(self):
    """Return the average number of hit tokens per ``match_prefix`` call.

    The value is the hit-token counter divided by the call counter, so it
    is a tokens-per-call average rather than a 0..1 rate.

    Returns:
        float: ``match_prefix_hit_tokens / match_prefix_total_calls``, or
        ``0.0`` when the call counter is not positive.
    """
    # Native ints keep the result a plain Python float instead of a NumPy
    # scalar.
    total_calls = int(self.match_prefix_total_calls.arr[0])
    if total_calls <= 0:
        return 0.0
    total_hit_tokens = int(self.match_prefix_hit_tokens.arr[0])
    return total_hit_tokens / total_calls
583+
def get_match_prefix_stats(self):
    """Return the ``match_prefix`` counters and the derived rate as a dict.

    Values are converted to native Python numbers so the dict can be
    serialized (for example with ``json.dumps``) without a custom encoder.

    Returns:
        dict: ``"total_calls"`` (int), ``"total_hit_tokens"`` (int) and
        ``"hit_rate"`` (float, the value of ``get_match_prefix_hit_rate()``).
    """
    return {
        "total_calls": int(self.match_prefix_total_calls.arr[0]),
        "total_hit_tokens": int(self.match_prefix_hit_tokens.arr[0]),
        "hit_rate": float(self.get_match_prefix_hit_rate()),
    }
591+
533592
534593class RadixCacheReadOnlyClient :
535594 def __init__ (self , unique_name , total_token_num , node_world_size , dp_world_size ):
@@ -546,3 +605,9 @@ def get_tree_total_tokens_num(self, dp_rank_in_node):
546605
def get_unrefed_tokens_num(self, dp_rank_in_node):
    """Forward to ``get_unrefed_tokens_num()`` on the client for ``dp_rank_in_node``.

    ``dp_rank_in_node`` is used as the index into ``self.dp_rank_clients``.
    """
    rank_client = self.dp_rank_clients[dp_rank_in_node]
    return rank_client.get_unrefed_tokens_num()
608+
def get_match_prefix_hit_rate(self, dp_rank_in_node):
    """Forward to ``get_match_prefix_hit_rate()`` on the client for ``dp_rank_in_node``.

    ``dp_rank_in_node`` is used as the index into ``self.dp_rank_clients``.
    """
    rank_client = self.dp_rank_clients[dp_rank_in_node]
    return rank_client.get_match_prefix_hit_rate()
611+
def get_match_prefix_stats(self, dp_rank_in_node):
    """Forward to ``get_match_prefix_stats()`` on the client for ``dp_rank_in_node``.

    ``dp_rank_in_node`` is used as the index into ``self.dp_rank_clients``.
    """
    rank_client = self.dp_rank_clients[dp_rank_in_node]
    return rank_client.get_match_prefix_stats()
0 commit comments