Skip to content

Commit 4af3ac5

Browse files
committed
done
1 parent 00fb9d7 commit 4af3ac5

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

lightllm/models/qwen3next/model.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,25 +90,9 @@ def _init_mem_manager(self):
9090

9191
@override
9292
def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
# Builds the infer state through the parent implementation, then attaches the
# batch's buffer indexes (looked up by model_input.b_req_idx) before returning it.
# In this diff view, a gutter marker ending in "-" (e.g. "93-") means the line
# that follows it was removed by the commit; a double-number marker (e.g. "9693")
# means the line is kept. After the commit only the kept lines remain.
93-
from lightllm.common.basemodel.infer_lock import g_infer_state_lock
94-
from lightllm.server.router.model_infer.infer_batch import g_infer_context
95-
9693
infer_state = super()._create_inferstate(model_input, microbatch_index)
9794

9895
buffer_indexes = self.req_manager.req_to_buffer_indexes[model_input.b_req_idx]
# Removed logic (old lines 99-111): find entries equal to EMPTY_BUFFER_INDEX,
# allocate replacements from mem_manager while holding g_infer_state_lock, and
# write them back into req_manager.req_to_buffer_indexes.
99-
empty_indexes = buffer_indexes == self.req_manager.EMPTY_BUFFER_INDEX
100-
num_empty = empty_indexes.sum()
101-
if num_empty == 0:
102-
return infer_state
103-
104-
g_infer_state_lock.acquire()
105-
if g_infer_context.radix_cache is not None:
106-
g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(num_empty)
107-
new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda()
108-
g_infer_state_lock.release()
109-
110-
buffer_indexes[empty_indexes] = new_buffer_indexes
111-
self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] = buffer_indexes
# Kept tail: expose the looked-up indexes on the infer state.
11296
infer_state.buffer_indexes = buffer_indexes
11397
return infer_state
11498

lightllm/models/qwen3next/req_manager.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,23 @@ def free_buffer(self, free_req_indexes: List[int]):
3131
self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes])
3232
self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX
3333
return
34+
35+
def alloc_buffer(self, req_indexes: List[int]):
    """Give every request in ``req_indexes`` that lacks a buffer a fresh one.

    Entries of ``self.req_to_buffer_indexes`` selected by ``req_indexes`` that
    equal ``self.EMPTY_BUFFER_INDEX`` are replaced with indexes obtained from
    ``self.mem_manager.alloc_buffer``; entries that already hold an index are
    written back unchanged. When nothing is empty the call is a no-op.

    Args:
        req_indexes: request slots to inspect (anything usable as an index
            into ``self.req_to_buffer_indexes``).

    Returns:
        None. ``self.req_to_buffer_indexes`` is updated in place.
    """
    cur_buffer_indexes = self.req_to_buffer_indexes[req_indexes]
    empty_indexes = cur_buffer_indexes == self.EMPTY_BUFFER_INDEX
    num_empty = empty_indexes.sum()
    if num_empty == 0:
        return

    # Function-scope imports, deferred until an allocation is actually needed
    # so the no-op path above does not depend on them.
    from lightllm.common.basemodel.infer_lock import g_infer_state_lock
    from lightllm.server.router.model_infer.infer_batch import g_infer_context

    g_infer_state_lock.acquire()
    try:
        if g_infer_context.radix_cache is not None:
            # NOTE(review): the inline code this method replaces (removed from
            # _create_inferstate in the same commit) called
            # free_radix_cache_to_get_enough_buffer here; confirm that the
            # *_token variant is the intended eviction call.
            g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_empty)
        new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda()
    finally:
        # Always release, even if eviction or allocation raises; otherwise the
        # lock would stay held after the exception propagates.
        g_infer_state_lock.release()

    cur_buffer_indexes[empty_indexes] = new_buffer_indexes
    self.req_to_buffer_indexes[req_indexes] = cur_buffer_indexes
    return

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def prefill_normal(
110110
model_input, run_reqs = prepare_prefill_inputs(
111111
prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
112112
)
113+
114+
if hasattr(g_infer_context.req_manager, "req_to_buffer_indexes"):
115+
g_infer_context.req_manager.alloc_buffer(model_input.b_req_idx)
116+
113117
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
114118
model_output = self.model.forward(model_input)
115119
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(

0 commit comments

Comments
 (0)