3 files changed: +24 −16 lines

server/router/model_infer/mode_backend/chunked_prefill

@@ -90,25 +90,9 @@ def _init_mem_manager(self):
 
     @override
     def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
-        from lightllm.common.basemodel.infer_lock import g_infer_state_lock
-        from lightllm.server.router.model_infer.infer_batch import g_infer_context
-
         infer_state = super()._create_inferstate(model_input, microbatch_index)
 
         buffer_indexes = self.req_manager.req_to_buffer_indexes[model_input.b_req_idx]
-        empty_indexes = buffer_indexes == self.req_manager.EMPTY_BUFFER_INDEX
-        num_empty = empty_indexes.sum()
-        if num_empty == 0:
-            return infer_state
-
-        g_infer_state_lock.acquire()
-        if g_infer_context.radix_cache is not None:
-            g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(num_empty)
-        new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda()
-        g_infer_state_lock.release()
-
-        buffer_indexes[empty_indexes] = new_buffer_indexes
-        self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] = buffer_indexes
         infer_state.buffer_indexes = buffer_indexes
         return infer_state
 
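With this hunk, `_create_inferstate` no longer imports the global lock or touches the radix cache; it only gathers already-assigned buffer slots. Below is a minimal, self-contained sketch of the per-request buffer-index table that this simplified path reads, assuming a dense torch tensor keyed by request slot; the class and method names are illustrative, not lightllm's actual `ReqManager`.

```python
import torch

class BufferTableSketch:
    """Illustrative stand-in for the req_to_buffer_indexes table (assumption, not lightllm code)."""

    EMPTY_BUFFER_INDEX = -1  # sentinel: request has no buffer slot assigned yet

    def __init__(self, max_request_num: int):
        # one entry per request slot, initialized to the empty sentinel
        self.req_to_buffer_indexes = torch.full(
            (max_request_num,), self.EMPTY_BUFFER_INDEX, dtype=torch.int64
        )

    def gather(self, b_req_idx: torch.Tensor) -> torch.Tensor:
        # what the simplified _create_inferstate does per microbatch:
        # a plain index lookup, with no locking and no allocation
        return self.req_to_buffer_indexes[b_req_idx]

# usage sketch
table = BufferTableSketch(max_request_num=8)
print(table.gather(torch.tensor([0, 3, 5])))  # all -1 until alloc_buffer fills them
```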
@@ -31,3 +31,23 @@ def free_buffer(self, free_req_indexes: List[int]):
         self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes])
         self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX
         return
+
+    def alloc_buffer(self, req_indexes: List[int]):
+        from lightllm.common.basemodel.infer_lock import g_infer_state_lock
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        cur_buffer_indexes = self.req_to_buffer_indexes[req_indexes]
+        empty_indexes = cur_buffer_indexes == self.EMPTY_BUFFER_INDEX
+        num_empty = empty_indexes.sum()
+        if num_empty == 0:
+            return
+
+        g_infer_state_lock.acquire()
+        if g_infer_context.radix_cache is not None:
+            g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_empty)
+        new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda()
+        g_infer_state_lock.release()
+
+        cur_buffer_indexes[empty_indexes] = new_buffer_indexes
+        self.req_to_buffer_indexes[req_indexes] = cur_buffer_indexes
+        return
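The new `ReqManager.alloc_buffer` fills only the empty slots for the given requests, evicting radix-cache entries under `g_infer_state_lock` when needed, and delegates the actual slot allocation to `self.mem_manager.alloc_buffer`, which is not part of this diff. A hedged sketch of such a pool-style allocator is shown below; the free-list design and all names here are assumptions for illustration only.

```python
import torch

class BufferPoolSketch:
    """Illustrative free-list pool (assumption); not the real lightllm mem_manager."""

    def __init__(self, num_buffers: int):
        # buffer slots currently available for allocation
        self.free_slots = list(range(num_buffers))

    def alloc_buffer(self, num: int) -> torch.Tensor:
        # the caller (ReqManager.alloc_buffer) is assumed to hold g_infer_state_lock
        # and to have freed radix-cache entries first if the pool was short
        assert num <= len(self.free_slots), "not enough free buffer slots"
        taken, self.free_slots = self.free_slots[:num], self.free_slots[num:]
        return torch.tensor(taken, dtype=torch.int64)

    def free_buffer(self, buffer_indexes: torch.Tensor) -> None:
        # return slots to the pool; callers filter out EMPTY_BUFFER_INDEX entries
        self.free_slots.extend(int(i) for i in buffer_indexes.tolist())
```

Allocating lazily here, rather than inside `_create_inferstate`, keeps the lock off the per-microbatch forward path and takes it at most once per prefill step.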
@@ -110,6 +110,10 @@ def prefill_normal(
         model_input, run_reqs = prepare_prefill_inputs(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
+
+        if hasattr(g_infer_context.req_manager, "req_to_buffer_indexes"):
+            g_infer_context.req_manager.alloc_buffer(model_input.b_req_idx)
+
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output = self.model.forward(model_input)
             _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
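Taken together, the three hunks move buffer-slot allocation out of per-microbatch infer-state creation and into a single guarded call per prefill step. A condensed sketch of the resulting control flow follows; `run_prefill_step` is an illustrative helper, not part of the patch, and it glosses over the overlap stream and sampling shown above.

```python
def run_prefill_step(req_manager, model, model_input):
    # allocate buffer slots once per prefill step, but only for backends whose
    # request manager actually carries a buffer table (mirrors the hasattr guard)
    if hasattr(req_manager, "req_to_buffer_indexes"):
        req_manager.alloc_buffer(model_input.b_req_idx)

    # _create_inferstate (first hunk) then just gathers the already-filled
    # indexes into infer_state.buffer_indexes; no lock is taken on that path
    return model.forward(model_input)
```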