 logger = init_logger(__name__)


-class MemoryManager:
-    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+class BaseAllocator:
+    def __init__(self, size, mem_manager_name=None):
         self.size = size
-        self.head_num = head_num
-        self.head_dim = head_dim
-        self.layer_num = layer_num
-        self.always_copy = always_copy
-        self.dtype = dtype
-        # profile the max total token num if the size is None
-        self.profile_size(mem_fraction)
 
         self.mem_state = torch.arange(
             0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
@@ -42,22 +35,101 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self.can_use_mem_size = self.size
 
         # Shared via shared memory so the router module can read it for accurate scheduling estimates; the nccl port marks a single instance on one machine and prevents name conflicts.
-        from lightllm.utils.envs_utils import get_unique_server_name
-
+        if mem_manager_name is None:
+            mem_manager_name = get_unique_server_name()
         rank_in_node = get_current_rank_in_node()
-        self.shared_can_use_token_num = SharedInt(
-            f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
-        )
+        self.shared_can_use_token_num = SharedInt(f"{mem_manager_name}_mem_manger_can_use_token_num_{rank_in_node}")
+
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size
+
+    def alloc(self, need_size) -> torch.Tensor:
+        if need_size > self.mark_end - self.mark_start:
+            logger.error(f"not enough cache: need_size {need_size} left_size {self.can_use_mem_size}")
+            assert False, "error alloc state"
+
+        start = self.mark_start
+        end = self.mark_start + need_size
+        self.mark_start += need_size
+
+        self.can_use_mem_size -= need_size
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        # Return through a rotating scratch buffer to avoid memory races under asynchronous execution
+        if self._return_start + need_size > self._mem_state_return.shape[0]:
+            self._return_start = 0
+        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
+        ans.copy_(self.mem_state[start:end])
+        self._return_start += need_size
+        return ans
+
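A note on the rotating `_mem_state_return` buffer above: `alloc` hands back a copy taken from a reusable scratch ring rather than a view into the live `mem_state` window, so a caller still holding a previous result cannot race with a later `free` that rewrites the same region. A minimal standalone sketch of the pattern (toy sizes; pinning and SharedInt omitted):

```python
import torch

mem_state = torch.arange(0, 16, dtype=torch.int32)   # the index pool
return_buf = torch.empty(16, dtype=torch.int32)      # scratch ring for results
return_start = 0

def alloc_copy(start: int, need_size: int) -> torch.Tensor:
    global return_start
    if return_start + need_size > return_buf.shape[0]:
        return_start = 0                              # wrap the ring
    out = return_buf[return_start : return_start + need_size]
    out.copy_(mem_state[start : start + need_size])   # snapshot, not a view
    return_start += need_size
    return out

a = alloc_copy(0, 4)
mem_state[0:4] = -1                # a later free() rewrites the same window
assert a.tolist() == [0, 1, 2, 3]  # the handed-out copy is unaffected
```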
+    def free(self, free_index: Union[torch.Tensor, List[int]]):
+        """Return token indexes to the allocator.
+
+        Args:
+            free_index (Union[torch.Tensor, List[int]]): the token indexes to release.
+        """
+
+        end = self.mark_start
+        start = self.mark_start - len(free_index)
+        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
+
+        if isinstance(free_index, list):
+            self.mem_state.numpy()[start:end] = free_index
+        else:
+            # Copies from GPU to CPU are stream-blocking operations
+            self.mem_state[start:end] = free_index
+
+        self.mark_start -= len(free_index)
+
+        self.can_use_mem_size += len(free_index)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.can_use_mem_size == len(self.mem_state):
+            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
+        return
+
+    def free_all(self):
+        self.can_use_mem_size = len(self.mem_state)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
+        self.mark_start = 0
+        self.mark_end = len(self.mem_state)
 
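Because `alloc` advances `mark_start` and `free` rolls it back, `mem_state` behaves as a stack of token indexes: frees must arrive in reverse allocation order (or net out to LIFO), which is what the `start >= 0` assert enforces. A toy model of that invariant, using plain lists instead of pinned tensors:

```python
# Toy model of the mark_start/mark_end stack discipline.
state = list(range(8))
mark_start, mark_end = 0, 8

def alloc(n):
    global mark_start
    assert n <= mark_end - mark_start, "not enough cache"
    out = state[mark_start : mark_start + n]
    mark_start += n
    return out

def free(idxs):
    global mark_start
    start = mark_start - len(idxs)
    assert start >= 0, "error free state"
    state[start:mark_start] = idxs   # write the indexes back into the window
    mark_start -= len(idxs)

a = alloc(3)
b = alloc(2)
free(b)                              # most recent allocation first
free(a)
assert mark_start == 0 and sorted(state) == list(range(8))
```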
+    def resize_mem(self, new_size):
+        """
+        Only used by test code.
+        """
+        self.size = new_size
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self.mark_start = 0
+        self.mark_end = self.size
+        self.can_use_mem_size = self.size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        return
+
+
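The `SharedInt` written in `__init__` mirrors `can_use_mem_size` into named shared memory so the router process can read the free-token count without an RPC; `get_unique_server_name()` plus the rank keeps the name unique per instance. `SharedInt` is lightllm's own wrapper; a rough stand-in for the idea using only the standard library (the class and names below are illustrative, not lightllm's API):

```python
from multiprocessing import shared_memory
import struct

class TinySharedInt:
    """Illustrative stand-in for SharedInt: one int64 in named shared memory."""

    def __init__(self, name: str, create: bool):
        if create:
            self.shm = shared_memory.SharedMemory(name=name, create=True, size=8)
        else:
            self.shm = shared_memory.SharedMemory(name=name)

    def set_value(self, v: int):
        struct.pack_into("q", self.shm.buf, 0, v)

    def get_value(self) -> int:
        return struct.unpack_from("q", self.shm.buf, 0)[0]

# The allocator side publishes capacity; a router process opens the same name and polls it.
writer = TinySharedInt("demo_mem_manger_can_use_token_num_0", create=True)
writer.set_value(4096)
reader = TinySharedInt("demo_mem_manger_can_use_token_num_0", create=False)
assert reader.get_value() == 4096
reader.shm.close()
writer.shm.close()
writer.shm.unlink()
```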
+class MemoryManager(BaseAllocator):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, mem_manager_name=None):
+        self.size = size
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self.always_copy = always_copy
+        self.dtype = dtype
+        # profile the max total token num if the size is None
+        self.profile_size(mem_fraction)
+        super().__init__(self.size, mem_manager_name)
+
         self._init_buffers(
             self.size,
             dtype,
             head_num,
             head_dim,
             layer_num,
         )
-        self.HOLD_TOKEN_MEMINDEX = self.size
 
     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
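`get_cell_size` is the per-token KV footprint: K and V (the factor 2) times head count, head dim, layer count, and element width. For a sense of scale, with hypothetical shapes (32 layers, 32 KV heads of dim 128, fp16):

```python
# Hypothetical shapes, for scale only.
head_num, head_dim, layer_num, elem_size = 32, 128, 32, 2   # fp16 -> 2 bytes
cell_size = 2 * head_num * head_dim * layer_num * elem_size
print(cell_size)                  # 524288 bytes, i.e. 512 KiB per cached token
print(cell_size * 4096 / 2**30)   # a 4096-token pool needs 2.0 GiB of kv_buffer
```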
@@ -93,7 +165,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
93165 """
94166 pd 分离模式使用的特殊接口
95167 """
-        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
+        if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
             raise NotImplementedError("subclass need reimpl this method")
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
@@ -103,7 +175,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return
 
     def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
-        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
+        if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
             raise NotImplementedError("subclass need reimpl this method")
 
         num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir)
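The hunk cuts off before the paged buffer is allocated, but the visible line shows it sizes the staging tensor from the model's true KV head count rather than this rank's `head_num`. By analogy with `alloc_kv_move_buffer` above, the shape presumably folds K and V into one axis; the layout below is an assumption for illustration, not taken from this diff:

```python
# Assumed, by analogy with kv_move_buffer's (1, len, 2 * head_num, head_dim) layout:
# page_num pages, each holding page_size tokens of K and V for num_kv_head heads.
def paged_kv_move_buffer_shape(page_num, page_size, num_kv_head, head_dim):
    return (page_num, page_size, 2 * num_kv_head, head_dim)

print(paged_kv_move_buffer_shape(page_num=8, page_size=256, num_kv_head=8, head_dim=128))
# (8, 256, 16, 128)
```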
@@ -320,59 +392,13 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
     def _free_buffers(self):
         self.kv_buffer = None
 
-    def alloc(self, need_size) -> torch.Tensor:
-        if need_size > self.mark_end - self.mark_start:
-            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
-            assert False, "error alloc state"
-
-        start = self.mark_start
-        end = self.mark_start + need_size
-        self.mark_start += need_size
-
-        self.can_use_mem_size -= need_size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        # Return through a rotating scratch buffer to avoid memory races under asynchronous execution
-        if self._return_start + need_size > self._mem_state_return.shape[0]:
-            self._return_start = 0
-        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
-        ans.copy_(self.mem_state[start:end])
-        self._return_start += need_size
-        return ans
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]):
344- """_summary_
345-
346- Args:
347- free_index (torch.Tensor): _description_
348- """
-
-        end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
-
-        if isinstance(free_index, list):
-            self.mem_state.numpy()[start:end] = free_index
-        else:
-            # Copies from GPU to CPU are stream-blocking operations
-            self.mem_state[start:end] = free_index
-
-        self.mark_start -= len(free_index)
-
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.can_use_mem_size == len(self.mem_state):
-            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
-        return
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kv_buffer[:, index]}
 
-    def free_all(self):
-        self.can_use_mem_size = len(self.mem_state)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
-        self.mark_start = 0
-        self.mark_end = len(self.mem_state)
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
 
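`get_index_kv_buffer` and `load_index_kv_buffer` give a dict-based export/import of the KV rows for a set of token indexes, a hook for offload or migration code. A shape-level sketch with a CPU stand-in for `kv_buffer`, assuming the `[layer_num, token_slot, 2 * head_num, head_dim]` axis order used elsewhere in this file:

```python
import torch

layer_num, size, heads2, head_dim = 4, 32, 8, 16
kv_buffer = torch.randn(layer_num, size, heads2, head_dim)   # CPU stand-in

index = torch.tensor([3, 7, 9])
exported = {"kv_buffer": kv_buffer[:, index]}   # advanced indexing returns a copy

kv_buffer[:, index] = 0.0                       # pretend the slots were recycled
kv_buffer[:, index] = exported["kv_buffer"]     # restore the saved rows
assert torch.equal(kv_buffer[:, index], exported["kv_buffer"])
```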
+    # Override resize_mem, adding the _free_buffers and _init_buffers calls
     def resize_mem(self, new_size):
         """
         Only used by test code.
@@ -383,24 +409,13 @@ def resize_mem(self, new_size):
         head_dim = self.head_dim
         layer_num = self.layer_num
 
-        self.size = new_size
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_start = 0
-        self.mark_end = self.size
-        self.can_use_mem_size = self.size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        # Call the parent class's resize_mem to reset the index state
+        super().resize_mem(new_size)
+
         self._free_buffers()
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
         return
 
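With this override, a test can resize the pool and get fresh KV buffers in one call: the base class resets the index bookkeeping, and only the buffer rebuild lives here. Hypothetical test-style usage (the constructor arguments and the `kv_buffer` axis order are placeholders):

```python
# Hypothetical test-style usage; arguments are placeholders.
mm = MemoryManager(size=1024, dtype=torch.float16, head_num=8, head_dim=128, layer_num=16)
assert mm.can_use_mem_size == 1024
mm.resize_mem(2048)                     # resets marks, reallocates kv_buffer
assert mm.can_use_mem_size == 2048
assert mm.kv_buffer.shape[1] == 2048    # assuming the [layer, token, ...] layout above
```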
-    def get_index_kv_buffer(self, index):
-        return {"kv_buffer": self.kv_buffer[:, index]}
-
-    def load_index_kv_buffer(self, index, load_tensor_dict):
-        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
-
 
 class ReadOnlyStaticsMemoryManager:
     """