Asynchicache #977
base: main
Changes from 8 commits
New file `radixmem_buffer.py` (+178 lines):

```python
import torch
from dataclasses import dataclass
import torch.multiprocessing as mp
from lightllm.utils.log_utils import init_logger
from typing import List, Union
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from multiprocessing.managers import DictProxy, ListProxy
from multiprocessing import Manager

logger = init_logger(__name__)


@dataclass
class SharedRadixMemoryData:
    kv_buffer: torch.Tensor
    mem_state: torch.Tensor
    req_mem_index: DictProxy
    lru_queue: ListProxy


@dataclass
class MemPropties:
    size: int
    dtype: torch.dtype
    head_num: int
    head_dim: int
    layer_num: int


shared_mem_data: SharedRadixMemoryData = None


def init_shared_data(mem_propties: MemPropties, device="cuda"):
    size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \
        mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num
    global shared_mem_data

    if device == "cuda":
        kv_buffer = torch.empty(
            (layer_num, size, head_num, head_dim),
            dtype=dtype,
            device="cuda"
        )
    else:
        kv_buffer = torch.empty(
            (layer_num, size, head_num, head_dim),
            dtype=dtype,
            device="cpu"
        ).share_memory_()

    mem_state = torch.arange(size, dtype=torch.int32).share_memory_()
    manager = Manager()
    req_mem_index = manager.dict()
    lru_queue = manager.list()

    shared_mem_data = SharedRadixMemoryData(
        kv_buffer=kv_buffer,
        mem_state=mem_state,
        req_mem_index=req_mem_index,
        lru_queue=lru_queue
    )


def get_shared_data() -> SharedRadixMemoryData:
    """Get the shared memory data."""
    global shared_mem_data
    if shared_mem_data is None:
        raise RuntimeError("Shared memory data has not been initialized. Call init_shared_data first.")
    return shared_mem_data


class RadixMemoryBuffer:
    def __init__(self, mem_propties: MemPropties, shared_data: SharedRadixMemoryData = None,
                 lock: mp.Lock = None, device="cuda", rank_in_node=None):
        size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \
            mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num

        self.kv_buffer = shared_data.kv_buffer
        self.mem_state = shared_data.mem_state
        self.req_mem_index = shared_data.req_mem_index
        self.lock = lock if lock is not None else mp.Lock()

        # TODO: profile size
        self.size = size  # number of token slots
        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        self.dtype = dtype

        can_use_mem_size = self.size
        mark_start = 0
        mark_end = self.size
        rank_in_node = rank_in_node if rank_in_node is not None else get_current_rank_in_node()
        self.can_use_mem_size = SharedInt(
            f"{get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}"
        )
        self.can_use_mem_size.set_value(can_use_mem_size)
        self.mark_start = SharedInt(
            f"{get_unique_server_name()}_radix_mem_manger_mark_start_{rank_in_node}"
        )
        self.mark_start.set_value(mark_start)

        self.mark_end = SharedInt(
            f"{get_unique_server_name()}_radix_mem_manger_mark_end_{rank_in_node}"
        )
        self.mark_end.set_value(mark_end)
        logger.info(f"create {get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}")

    def _free(self, free_index: Union[torch.Tensor, List[int]]):
        """Return token slots to the shared free pool.

        Args:
            free_index (Union[torch.Tensor, List[int]]): slot indices to release.
        """
        end = self.mark_start.get_value()
        start = end - len(free_index)
        assert start >= 0, f"error free state start: {end} free len {len(free_index)}"

        if isinstance(free_index, list):
            self.mem_state.numpy()[start:end] = free_index
        else:
            # The copy from GPU to CPU is a blocking operation within the stream.
            self.mem_state[start:end] = free_index
```
Contributor comment on lines +119 to +123:

There are a few areas for improvement in this new file for better maintainability and clarity:
```python
        self.mark_start.set_value(end - len(free_index))

        self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() + len(free_index))

        if self.can_use_mem_size.get_value() == len(self.mem_state):
            logger.debug(f"freed all gpu mem size {self.can_use_mem_size.get_value()}")
        return

    def free_req_index(self, req_id: int):
        """Free the memory index for a specific request ID."""
        with self.lock:
            if req_id not in self.req_mem_index:
                logger.warning(f"Request ID {req_id} not found in memory index.")
                return
            index = self.req_mem_index[req_id]
            self._free(index)
            logger.info(f"Freed memory index for request {req_id} size {len(index)}, left size {self.can_use_mem_size.get_value()}")
            del self.req_mem_index[req_id]

    def alloc(self, need_size) -> torch.Tensor:
        with self.lock:
            if need_size > self.mark_end.get_value() - self.mark_start.get_value():
                logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size.get_value()}")
                raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.")

            start = self.mark_start.get_value()
            end = start + need_size
            ans = self.mem_state[start:end]
            self.mark_start.set_value(start + need_size)

            self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() - need_size)
            return ans

    def set_req_mem_index(self, req_id: int, index: List[int]):
        """Set the memory index for a specific request ID."""
        with self.lock:
            if req_id in self.req_mem_index:
                logger.info(f"Request ID {req_id} already exists. Overwriting index {self.req_mem_index[req_id]} with {index}.")
            self.req_mem_index[req_id] = index
            logger.info(f"radix mem buffer insert req {req_id}, current disk work num {self._get_current_work_num()}")

    def get_req_mem_index(self, req_id: int) -> List[int]:
        """Get the memory index for a specific request ID."""
        with self.lock:
            if req_id not in self.req_mem_index:
                logger.warning(f"Request ID {req_id} not found. Returning empty list.")
                return []
            return self.req_mem_index[req_id]

    def get_kv_buffer(self, index) -> torch.Tensor:
        with self.lock:
            return self.kv_buffer[:, index, :, :]

    def _get_current_work_num(self) -> int:
        return len(self.req_mem_index)
```
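For orientation, here is a minimal usage sketch of the API above. It is an illustration rather than code from the PR: the `MemPropties` values, the rank, and the request id are invented, and it assumes a lightllm server context in which `get_unique_server_name()` and the `SharedInt` shared-memory segments can be created.

```python
# Hypothetical usage sketch (not part of the PR); sizes, dtype, rank, and the
# request id are made-up values chosen only to exercise the API defined above.
import torch

props = MemPropties(size=1024, dtype=torch.float16, head_num=8, head_dim=128, layer_num=16)

# One process allocates the shared tensors and Manager-backed containers ...
init_shared_data(props, device="cpu")
shared = get_shared_data()

# ... and each worker wraps them in a RadixMemoryBuffer (ideally sharing one lock).
buf = RadixMemoryBuffer(props, shared_data=shared, device="cpu", rank_in_node=0)

slots = buf.alloc(64)                                # reserve 64 token slots
buf.set_req_mem_index(req_id=7, index=slots.tolist())
kv = buf.get_kv_buffer(buf.get_req_mem_index(7))     # view into the shared kv_buffer
buf.free_req_index(7)                                # hand the slots back to the pool
```

The allocator itself is a simple bump pointer over `mem_state`: `alloc` advances `mark_start`, while `_free` (reached through `free_req_index`) writes the released indices back and moves `mark_start` down again.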
Second new file (+119 lines; the excerpt below is truncated):

```python
import torch
import time
import xxhash
import numpy as np
from typing import List, Dict, Tuple, Optional
import torch.multiprocessing as mp
from collections import OrderedDict

from .radixmem_buffer import SharedRadixMemoryData, RadixMemoryBuffer

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class RadixBufferManager:

    def __init__(self,
                 radix_buffer: RadixMemoryBuffer = None,
                 radix_mem_data: SharedRadixMemoryData = None,
                 lock: Optional[mp.Lock] = None,
                 max_entries: int = 10000,
                 chunk_size: int = 64
                 ):
        self.chunk_size = chunk_size
        self.max_entries = max_entries
        self.radix_buffer = radix_buffer
        self.lru_queue = radix_mem_data.lru_queue

        self.lock = lock if lock is not None else mp.Lock()

    def _compute_hash(self, tokens: List[int]) -> List[Tuple[int, List[int]]]:
        chunks = []
        hsum = xxhash.xxh3_64()
        cumulative_tokens = []

        for i in range(0, len(tokens), self.chunk_size):
            chunk = tokens[i:i + self.chunk_size]
            cumulative_tokens.extend(chunk)

            chunk_np = np.array(chunk, dtype=np.uint32)
            hsum.update(chunk_np.tobytes())

            current_hash = hsum.intdigest()
            chunks.append((current_hash, cumulative_tokens.copy()))

        return chunks

    def write(self, tokens: List[int], values: torch.Tensor, start_pos: int) -> None:
        with self.lock:
            index = start_pos // self.chunk_size
            chunks = self._compute_hash(tokens)

            values = values[index * self.chunk_size:]
            chunks = chunks[index:]
            for i, (hash_val, _) in enumerate(chunks):
                if hash not in self.radix_buffer.req_mem_index:
```
Suggested change: replace `if hash not in self.radix_buffer.req_mem_index:` with `if hash_val not in self.radix_buffer.req_mem_index:` (the loop variable is `hash_val`; `hash` is the Python builtin).
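One property of `_compute_hash` worth spelling out: because a single `xxh3_64` state is updated chunk after chunk, each emitted hash covers the entire token prefix up to and including that chunk, so two sequences that share a prefix also share the corresponding leading hashes. A small standalone check, with a chunk size of 4 assumed for brevity (the class default is 64):

```python
# Standalone illustration of the prefix-hash behaviour in _compute_hash.
# chunk_size=4 is an assumption picked for brevity; RadixBufferManager defaults to 64.
import numpy as np
import xxhash


def prefix_chunk_hashes(tokens, chunk_size=4):
    hsum = xxhash.xxh3_64()
    out = []
    for i in range(0, len(tokens), chunk_size):
        chunk = np.array(tokens[i:i + chunk_size], dtype=np.uint32)
        hsum.update(chunk.tobytes())   # the hash state accumulates across chunks
        out.append(hsum.intdigest())   # hash of tokens[0 : i + chunk_size]
    return out


a = prefix_chunk_hashes([1, 2, 3, 4, 5, 6, 7, 8])
b = prefix_chunk_hashes([1, 2, 3, 4, 9, 9, 9, 9])
assert a[0] == b[0]  # shared 4-token prefix -> identical first hash
assert a[1] != b[1]  # diverging second chunk -> different second hash
```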
Contributor comment:

This block of code for initializing the `hiradix_cache` is duplicated in multiple model files, specifically:

- lightllm/models/deepseek2/model.py
- lightllm/models/qwen2/model.py

This duplication makes the code harder to maintain and increases the risk of inconsistencies if changes are needed in the future. To improve maintainability, this logic should be refactored into a single, reusable helper method within the `TpPartBaseModel` class.
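A possible shape for that refactor, sketched only to make the suggestion concrete: the helper name and the subclass shown here are placeholders, and the actual `hiradix_cache` construction is elided because the duplicated block itself is not part of this excerpt.

```python
# Sketch of the reviewer's suggestion; apart from TpPartBaseModel, the names are
# placeholders and the real cache construction is intentionally left out.
class TpPartBaseModel:
    def _init_hiradix_cache(self):  # hypothetical helper name
        # The duplicated hiradix_cache initialization from
        # lightllm/models/deepseek2/model.py and lightllm/models/qwen2/model.py
        # would live here once, unchanged.
        self.hiradix_cache = None  # placeholder for the shared setup logic


class SomeTpPartModel(TpPartBaseModel):  # stands in for the deepseek2 / qwen2 models
    def __init__(self):
        super().__init__()
        # Each concrete model calls the shared helper instead of repeating the block.
        self._init_hiradix_cache()
```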