|
11 | 11 | from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager |
12 | 12 | from lightllm.server.core.objs.start_args_type import StartArgs |
13 | 13 | from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput |
| 14 | +from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager |
14 | 15 |
|
15 | 16 | logger = init_logger(__name__) |
16 | 17 |
|
@@ -83,18 +84,42 @@ def _init_mem_manager(self): |
83 | 84 | self.head_linear_k_dim, |
84 | 85 | self.head_linear_v_dim, |
85 | 86 | ), |
| 87 | + max_req_num=self.max_req_num, |
86 | 88 | mem_fraction=self.mem_fraction, |
87 | 89 | ) |
88 | 90 |
|
@override
def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
    """Build the infer state and attach a buffer index for every request in the batch.

    The buffer index recorded for each request in ``model_input.b_req_idx`` is
    read from ``self.req_manager.req_to_buffer_indexes``. Entries equal to
    ``EMPTY_BUFFER_INDEX`` are replaced with indexes freshly obtained from
    ``self.mem_manager.alloc_buffer`` and written back to the request manager.

    Args:
        model_input: Batch description; only ``b_req_idx`` is read here, the
            rest is consumed by the parent implementation.
        microbatch_index: Forwarded unchanged to the parent implementation.

    Returns:
        The infer state produced by the parent implementation, with
        ``buffer_indexes`` set for the whole batch.
    """
    from lightllm.common.basemodel.infer_lock import g_infer_state_lock
    from lightllm.server.router.model_infer.infer_batch import g_infer_context

    infer_state = super()._create_inferstate(model_input, microbatch_index)

    req_manager = self.req_manager
    buffer_indexes = req_manager.req_to_buffer_indexes[model_input.b_req_idx]
    empty_indexes = buffer_indexes == req_manager.EMPTY_BUFFER_INDEX
    # Convert once to a plain int so the allocator and the radix cache receive a
    # Python count rather than a 0-dim tensor.
    num_empty = int(empty_indexes.sum())

    if num_empty > 0:
        g_infer_state_lock.acquire()
        try:
            if g_infer_context.radix_cache is not None:
                # Ask the radix cache to release buffers before allocating.
                g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(num_empty)
            new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda()
        finally:
            # Release even when freeing/allocating raises, so a failed
            # allocation cannot leave the lock held.
            g_infer_state_lock.release()

        buffer_indexes[empty_indexes] = new_buffer_indexes
        # The indexed read above is not relied on as a view; persist explicitly.
        req_manager.req_to_buffer_indexes[model_input.b_req_idx] = buffer_indexes

    # Set on every path: a batch whose requests all already own a buffer still
    # needs its indexes on the infer state.
    infer_state.buffer_indexes = buffer_indexes
    return infer_state
| 114 | + |
@override
def _init_req_manager(self):
    """Create ``self.req_manager`` as a ``Qwen3NextReqManager``.

    The sequence-length capacity handed to the manager is the largest of
    ``self.batch_max_tokens`` and ``self.max_seq_length`` among those that are
    not ``None``, with a floor of 0.
    """
    configured_limits = [
        limit
        for limit in (self.batch_max_tokens, self.max_seq_length)
        if limit is not None
    ]
    # Leading 0 keeps the floor and covers the case where both limits are unset.
    create_max_seq_len = max([0, *configured_limits])

    self.req_manager = Qwen3NextReqManager(
        self.max_req_num,
        create_max_seq_len,
        self.mem_manager,
    )
0 commit comments