@@ -31,11 +31,6 @@ def __init__(self):
3131 self .node_value_len = 0
3232 self .node_prefix_total_len = 0
3333
34- # 用于混合线性注意力模型, 例如Qwen3Next
35- # 在混合线性注意力情景中,buffer_idx 可以有值也可以为None
36- # 但是如果为None则不能作为最终改的匹配节点
37- self .buffer_idx = None
38-
# Build this node's comparison key as a 3-tuple:
#   (0 if ref_counter == 0 else 1, number of children, time_id).
# Tuples compare element by element, so nodes with ref_counter == 0 sort first,
# then nodes with fewer children, then nodes with a smaller time_id.
# NOTE(review): presumably this key orders the eviction set -- confirm at the call site.
3934 def get_compare_key (self ):
4035 return (0 if self .ref_counter == 0 else 1 , len (self .children ), self .time_id )
4136
@@ -130,16 +125,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
130125 )
131126 self .tree_total_tokens_num .arr [0 ] = 0
132127
133- # Hit rate tracking
134- self .match_prefix_total_calls = SharedArray (
135- f"{ unique_name } _match_prefix_total_calls_{ rank_in_node } " , (1 ,), dtype = np .int64
136- )
137- self .match_prefix_total_calls .arr [0 ] = 0
138- self .match_prefix_hit_tokens = SharedArray (
139- f"{ unique_name } _match_prefix_hit_tokens_{ rank_in_node } " , (1 ,), dtype = np .int64
140- )
141- self .match_prefix_hit_tokens .arr [0 ] = 0
142-
143128 def insert (self , key , value = None ) -> Tuple [int , Optional [TreeNode ]]:
144129 if value is None :
145130 value = key
@@ -247,21 +232,13 @@ def _insert_helper_no_recursion(
247232
248233 def match_prefix (self , key , update_refs = False ):
249234 assert len (key ) != 0
250-
251- # Track total calls
252- self .match_prefix_total_calls .arr [0 ] += 1
253-
254235 ans_value_list = []
255236 tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
256237 if tree_node != self .root_node :
257238 if len (ans_value_list ) != 0 :
258239 value = torch .concat (ans_value_list )
259240 else :
260241 value = torch .zeros ((0 ,), device = "cpu" , dtype = self ._value_dtype )
261-
262- # Track hit tokens
263- self .match_prefix_hit_tokens .arr [0 ] += len (value )
264-
265242 return tree_node , len (value ), value
266243 else :
267244 self .dec_node_ref_counter (self .root_node )
@@ -343,29 +320,27 @@ def _match_prefix_helper_no_recursion(
343320 else :
344321 assert False , "error state"
345322
# Evict nodes from the eviction set until at least `need_remove_tokens` tokens are released.
#
# Diff note: the "346-" signature (taking `need_remove_buffers`) is the removed version and
# the "323+" signature is the current one. The other "-" lines below (the `release_buffers`
# bookkeeping) are removed as well.
#
# Args:
#   need_remove_tokens: minimum number of tokens to release.
#   evict_callback: called once per evicted node with that node's `token_mem_index_value`.
# Returns:
#   None in the current version (bare `return`). The removed version returned the list of
#   `buffer_idx` values collected from the evicted nodes.
# Raises:
#   AssertionError when tree_total_tokens_num - refed_tokens_num is smaller than
#   `need_remove_tokens`, or when a popped node is still referenced, has children,
#   or is the root node.
346- def evict (self , need_remove_tokens , need_remove_buffers , evict_callback ):
323+ def evict (self , need_remove_tokens , evict_callback ):
# Refuse up front if the unreferenced part of the tree cannot cover the request.
347324 if self .tree_total_tokens_num .arr [0 ] - self .refed_tokens_num .arr [0 ] < need_remove_tokens :
348325 assert False , f"""can not free tree tokens { need_remove_tokens } ,
349326 tree_total_tokens_num { self .tree_total_tokens_num .arr [0 ]} ,
350327 refed_tokens_num { self .refed_tokens_num .arr [0 ]} """
351328 num_evicted = 0
352- release_buffers = []
# Whole nodes are evicted, so num_evicted may end up larger than need_remove_tokens.
353329 while num_evicted < need_remove_tokens :
# NOTE(review): pop(0) presumably yields the best eviction candidate under the set's
# ordering -- confirm how evict_tree_set is sorted.
354330 node : TreeNode = self .evict_tree_set .pop (0 )
355331 assert (
356332 node .ref_counter == 0 and len (node .children ) == 0 and node != self .root_node
357333 ), "error evict tree node state"
358334 num_evicted += len (node .token_mem_index_value )
# Hand the node's memory indexes to the caller; this method does not free them itself.
359335 evict_callback (node .token_mem_index_value )
360- release_buffers .append (node .buffer_idx )
361336 # update total token num
362337 self .tree_total_tokens_num .arr [0 ] -= len (node .token_mem_index_value )
363338 parent_node : TreeNode = node .parent
364339 parent_node .remove_child (node )
# If the parent reports is_leaf() after losing this child, it becomes an eviction candidate.
365340 if parent_node .is_leaf ():
366341 self .evict_tree_set .add (parent_node )
367342
368- return release_buffers
343+ return
369344
370345 def _try_merge (self , child_node : TreeNode ) -> Optional [TreeNode ]:
371346 """
@@ -519,38 +494,20 @@ def _print_helper(self, node: TreeNode, indent):
519494 self ._print_helper (child , indent = indent + 2 )
520495 return
521496
# Free cached tokens through the memory manager when it cannot supply `need_token_num` tokens.
#
# Diff note: lines marked "-" are the removed version (extra `need_evict_buffer_num` parameter,
# `release_buffers` return value, and the two `get_match_prefix_*` methods that follow).
# Lines marked "+" are the current version.
#
# Args:
#   need_token_num: number of tokens the caller needs available.
# Returns:
#   None in the current version (bare `return`).
# Raises:
#   AssertionError if `self.mem_manager` is None. Assertions raised by `evict` propagate.
522- def free_radix_cache_to_get_enough_token (self , need_token_num , need_evict_buffer_num = 0 ):
497+ def free_radix_cache_to_get_enough_token (self , need_token_num ):
523498 assert self .mem_manager is not None
# Current version: nothing is evicted unless the request exceeds can_use_mem_size.
524- if need_token_num > self .mem_manager .can_use_mem_size or need_evict_buffer_num > 0 :
499+ if need_token_num > self .mem_manager .can_use_mem_size :
# Only the shortfall is requested from the tree.
525500 need_evict_token_num = need_token_num - self .mem_manager .can_use_mem_size
526501 release_mems = []
527502
# Callback given to evict(): collects each evicted node's memory indexes.
528503 def release_mem (mem_index ):
529504 release_mems .append (mem_index )
530505 return
531506
532- release_buffers = self .evict (need_evict_token_num , need_evict_buffer_num , release_mem )
507+ self .evict (need_evict_token_num , release_mem )
# Concatenate the collected indexes and free them in a single mem_manager.free call.
533508 mem_index = torch .concat (release_mems )
534509 self .mem_manager .free (mem_index )
535- return release_buffers
536-
537- def get_match_prefix_hit_rate (self ):
538- """Get the hit rate as a ratio of hit tokens to total requested tokens"""
539- total_calls = self .match_prefix_total_calls .arr [0 ]
540- if total_calls == 0 :
541- return 0.0
542- # We calculate hit rate as the average hit tokens per call
543- # Note: This is a simplified metric. For true hit rate, you might want to track total requested tokens
544- total_hit_tokens = self .match_prefix_hit_tokens .arr [0 ]
545- return total_hit_tokens / total_calls if total_calls > 0 else 0.0
546-
547- def get_match_prefix_stats (self ):
548- """Get detailed match_prefix statistics"""
549- return {
550- "total_calls" : self .match_prefix_total_calls .arr [0 ],
551- "total_hit_tokens" : self .match_prefix_hit_tokens .arr [0 ],
552- "hit_rate" : self .get_match_prefix_hit_rate (),
553- }
510+ return
554511
555512
556513class _RadixCacheReadOnlyClient :
@@ -563,13 +520,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node):
563520 self .tree_total_tokens_num = SharedArray (
564521 f"{ unique_name } _tree_total_tokens_num_{ rank_in_node } " , (1 ,), dtype = np .int64
565522 )
566- # Hit rate tracking
567- self .match_prefix_total_calls = SharedArray (
568- f"{ unique_name } _match_prefix_total_calls_{ rank_in_node } " , (1 ,), dtype = np .int64
569- )
570- self .match_prefix_hit_tokens = SharedArray (
571- f"{ unique_name } _match_prefix_hit_tokens_{ rank_in_node } " , (1 ,), dtype = np .int64
572- )
573523
# Return element 0 of `self.refed_tokens_num.arr`.
# NOTE(review): presumably a SharedArray counter like tree_total_tokens_num above;
# its construction is outside this view -- confirm.
574524 def get_refed_tokens_num (self ):
575525 return self .refed_tokens_num .arr [0 ]
@@ -580,22 +530,6 @@ def get_tree_total_tokens_num(self):
# Return tree_total_tokens_num minus refed_tokens_num (element 0 of each array).
580530 def get_unrefed_tokens_num (self ):
581531 return self .tree_total_tokens_num .arr [0 ] - self .refed_tokens_num .arr [0 ]
582532
583- def get_match_prefix_hit_rate (self ):
584- """Get the hit rate as a ratio of hit tokens to total calls"""
585- total_calls = self .match_prefix_total_calls .arr [0 ]
586- if total_calls == 0 :
587- return 0.0
588- total_hit_tokens = self .match_prefix_hit_tokens .arr [0 ]
589- return total_hit_tokens / total_calls if total_calls > 0 else 0.0
590-
591- def get_match_prefix_stats (self ):
592- """Get detailed match_prefix statistics"""
593- return {
594- "total_calls" : self .match_prefix_total_calls .arr [0 ],
595- "total_hit_tokens" : self .match_prefix_hit_tokens .arr [0 ],
596- "hit_rate" : self .get_match_prefix_hit_rate (),
597- }
598-
599533
600534class RadixCacheReadOnlyClient :
601535 def __init__ (self , unique_name , total_token_num , node_world_size , dp_world_size ):
@@ -612,9 +546,3 @@ def get_tree_total_tokens_num(self, dp_rank_in_node):
612546
# Delegate to the client stored at `self.dp_rank_clients[dp_rank_in_node]` and return
# its get_unrefed_tokens_num() result.
613547 def get_unrefed_tokens_num (self , dp_rank_in_node ):
614548 return self .dp_rank_clients [dp_rank_in_node ].get_unrefed_tokens_num ()
615-
616- def get_match_prefix_hit_rate (self , dp_rank_in_node ):
617- return self .dp_rank_clients [dp_rank_in_node ].get_match_prefix_hit_rate ()
618-
619- def get_match_prefix_stats (self , dp_rank_in_node ):
620- return self .dp_rank_clients [dp_rank_in_node ].get_match_prefix_stats ()
0 commit comments