Commit 5e2375d

Add BaseAllocator

1 parent 1f54ecd commit 5e2375d
File tree

1 file changed

lightllm/common/mem_manager.py: 97 additions & 82 deletions
@@ -18,16 +18,9 @@
 logger = init_logger(__name__)
 
 
-class MemoryManager:
-    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+class BaseAllocator:
+    def __init__(self, size, mem_manager_name=None):
         self.size = size
-        self.head_num = head_num
-        self.head_dim = head_dim
-        self.layer_num = layer_num
-        self.always_copy = always_copy
-        self.dtype = dtype
-        # profile the max total token num if the size is None
-        self.profile_size(mem_fraction)
 
         self.mem_state = torch.arange(
             0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
@@ -42,22 +35,101 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self.can_use_mem_size = self.size
 
         # Shared via shared memory: the router module reads it for accurate scheduling estimates, and the nccl port marks a single instance on one machine to prevent conflicts.
-        from lightllm.utils.envs_utils import get_unique_server_name
-
+        if mem_manager_name is None:
+            mem_manager_name = get_unique_server_name()
         rank_in_node = get_current_rank_in_node()
-        self.shared_can_use_token_num = SharedInt(
-            f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
-        )
+        self.shared_can_use_token_num = SharedInt(f"{mem_manager_name}_mem_manger_can_use_token_num_{rank_in_node}")
+
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size
+
+    def alloc(self, need_size) -> torch.Tensor:
+        if need_size > self.mark_end - self.mark_start:
+            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
+            assert False, "error alloc state"
+
+        start = self.mark_start
+        end = self.mark_start + need_size
+        self.mark_start += need_size
+
+        self.can_use_mem_size -= need_size
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        # Return through a staging buffer to avoid memory races under asynchronous use.
+        if self._return_start + need_size > self._mem_state_return.shape[0]:
+            self._return_start = 0
+        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
+        ans.copy_(self.mem_state[start:end])
+        self._return_start += need_size
+        return ans
+
+    def free(self, free_index: Union[torch.Tensor, List[int]]):
+        """Return the given token indices to the allocator's free pool.
+
+        Args:
+            free_index (torch.Tensor | List[int]): token indices to release.
+        """
+
+        end = self.mark_start
+        start = self.mark_start - len(free_index)
+        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
+
+        if isinstance(free_index, list):
+            self.mem_state.numpy()[start:end] = free_index
+        else:
+            # A copy from GPU to CPU is a stream-blocking operation.
+            self.mem_state[start:end] = free_index
+
+        self.mark_start -= len(free_index)
+
+        self.can_use_mem_size += len(free_index)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.can_use_mem_size == len(self.mem_state):
+            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
+        return
+
+    def free_all(self):
+        self.can_use_mem_size = len(self.mem_state)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
+        self.mark_start = 0
+        self.mark_end = len(self.mem_state)
 
+    def resize_mem(self, new_size):
+        """
+        just for test code
+        """
+        self.size = new_size
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self.mark_start = 0
+        self.mark_end = self.size
+        self.can_use_mem_size = self.size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        return
+
+
+class MemoryManager(BaseAllocator):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, mem_manager_name=None):
+        self.size = size
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self.always_copy = always_copy
+        self.dtype = dtype
+        # profile the max total token num if the size is None
+        self.profile_size(mem_fraction)
+        super().__init__(self.size, mem_manager_name)
+
         self._init_buffers(
             self.size,
             dtype,
             head_num,
             head_dim,
             layer_num,
         )
-        self.HOLD_TOKEN_MEMINDEX = self.size
 
     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
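
The hunk above moves the index bookkeeping out of MemoryManager into the new BaseAllocator, which manages a pool of token indices like a stack: alloc hands out the slice [mark_start, mark_start + need_size) and free pushes indices back below mark_start. For orientation, here is a minimal, dependency-free sketch of that scheme; ToyAllocator and everything in it are illustrative names, not lightllm APIs, and the real class additionally publishes the free count through SharedInt and returns results via the _mem_state_return staging buffer.

import torch

class ToyAllocator:
    """Simplified sketch of BaseAllocator's stack-style index bookkeeping
    (no SharedInt, no shared memory, no staging buffer)."""

    def __init__(self, size):
        self.size = size
        # pool of free token indices; lightllm keeps this pinned on the CPU
        self.mem_state = torch.arange(0, size, dtype=torch.int32)
        self.mark_start = 0  # everything below this mark is allocated
        self.mark_end = size
        self.can_use_mem_size = size

    def alloc(self, need_size):
        assert need_size <= self.mark_end - self.mark_start, "not enough cache"
        # .clone() stands in for the _mem_state_return staging buffer
        ans = self.mem_state[self.mark_start : self.mark_start + need_size].clone()
        self.mark_start += need_size
        self.can_use_mem_size -= need_size
        return ans

    def free(self, free_index):
        start = self.mark_start - len(free_index)
        assert start >= 0, "freeing more than was allocated"
        # freed indices are written back just below mark_start
        self.mem_state[start : self.mark_start] = free_index
        self.mark_start = start
        self.can_use_mem_size += len(free_index)

allocator = ToyAllocator(8)
idx = allocator.alloc(3)        # tensor([0, 1, 2], dtype=torch.int32)
allocator.free(idx)             # indices return to the pool
assert allocator.can_use_mem_size == 8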
@@ -93,7 +165,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         """
         Special interface used in pd-disaggregation (prefill/decode split) mode.
         """
-        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
+        if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
             raise NotImplementedError("subclass need reimpl this method")
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
@@ -103,7 +175,7 @@
         return
 
     def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
-        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
+        if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
             raise NotImplementedError("subclass need reimpl this method")
 
         num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir)
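
Both guards above switch from type(self) != MemoryManager to type(self) is not MemoryManager: for an exact-type check, identity comparison of class objects is the idiomatic spelling (pycodestyle/flake8 flags the != form as E721). A toy illustration of the guard's effect, with invented class names:

class Base:
    def alloc_kv_move_buffer(self):
        # exact-type guard: subclasses must supply their own implementation
        if isinstance(self, Base) and type(self) is not Base:
            raise NotImplementedError("subclass need reimpl this method")
        return "buffer"

class Child(Base):
    pass

print(Base().alloc_kv_move_buffer())    # "buffer"
try:
    Child().alloc_kv_move_buffer()
except NotImplementedError as e:
    print(e)                            # subclass need reimpl this method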
@@ -320,59 +392,13 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
     def _free_buffers(self):
         self.kv_buffer = None
 
-    def alloc(self, need_size) -> torch.Tensor:
-        if need_size > self.mark_end - self.mark_start:
-            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
-            assert False, "error alloc state"
-
-        start = self.mark_start
-        end = self.mark_start + need_size
-        self.mark_start += need_size
-
-        self.can_use_mem_size -= need_size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        # Return through a staging buffer to avoid memory races under asynchronous use.
-        if self._return_start + need_size > self._mem_state_return.shape[0]:
-            self._return_start = 0
-        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
-        ans.copy_(self.mem_state[start:end])
-        self._return_start += need_size
-        return ans
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]):
-        """Return the given token indices to the allocator's free pool.
-
-        Args:
-            free_index (torch.Tensor | List[int]): token indices to release.
-        """
-
-        end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
-
-        if isinstance(free_index, list):
-            self.mem_state.numpy()[start:end] = free_index
-        else:
-            # A copy from GPU to CPU is a stream-blocking operation.
-            self.mem_state[start:end] = free_index
-
-        self.mark_start -= len(free_index)
-
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.can_use_mem_size == len(self.mem_state):
-            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
-        return
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kv_buffer[:, index]}
 
-    def free_all(self):
-        self.can_use_mem_size = len(self.mem_state)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
-        self.mark_start = 0
-        self.mark_end = len(self.mem_state)
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
 
+    # Override resize_mem, adding _free_buffers and _init_buffers calls.
     def resize_mem(self, new_size):
         """
         just for test code
@@ -383,24 +409,13 @@ def resize_mem(self, new_size):
         head_dim = self.head_dim
         layer_num = self.layer_num
 
-        self.size = new_size
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_start = 0
-        self.mark_end = self.size
-        self.can_use_mem_size = self.size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        # Call the parent class's resize_mem.
+        super().resize_mem(new_size)
+
         self._free_buffers()
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
         return
 
-    def get_index_kv_buffer(self, index):
-        return {"kv_buffer": self.kv_buffer[:, index]}
-
-    def load_index_kv_buffer(self, index, load_tensor_dict):
-        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
-
 
 class ReadOnlyStaticsMemoryManager:
     """

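The net effect of the last two hunks is a template-method split: BaseAllocator.resize_mem resets the index-pool state, and MemoryManager.resize_mem calls it via super() before rebuilding its KV buffers. A self-contained sketch of that shape, using toy classes rather than the lightllm API:

class ToyBase:
    def resize_mem(self, new_size):
        # base class: reset the index-pool bookkeeping only
        self.size = new_size
        self.mem_state = list(range(new_size))
        self.mark_start, self.mark_end = 0, new_size

class ToyMemoryManager(ToyBase):
    def resize_mem(self, new_size):
        super().resize_mem(new_size)   # index state handled by the base
        self._free_buffers()           # subclass: rebuild KV storage around it
        self._init_buffers(new_size)

    def _free_buffers(self):
        self.kv_buffer = None

    def _init_buffers(self, size):
        self.kv_buffer = [0.0] * size

mm = ToyMemoryManager()
mm.resize_mem(4)
assert mm.size == 4 and len(mm.kv_buffer) == 4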