5 files changed: +19 -21 lines changed
router/req_queue/chunked_prefill (5 files changed: +19 -21 lines changed)
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ def init_req_sampling_params(self, req):
155155 else :
156156 self .req_to_out_token_id_counter [req .req_idx ].fill_ (0 )
157157 if req .sampling_param .shm_param .input_penalty and req .need_out_token_id_statistics :
158- prompt_ids = torch .from_numpy (req .shm_req .get_prompt_ids ()).pin_memory ().cuda (non_blocking = True )
158+ prompt_ids = torch .from_numpy (req .shm_req .get_prompt_ids_numpy ()).pin_memory ().cuda (non_blocking = True )
159159 token_id_counter (
160160 prompt_ids = prompt_ids , out_token_id_counter = self .req_to_out_token_id_counter [req .req_idx ]
161161 )
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,9 @@ def link_logprobs_shm_array(self):
188188 def get_prompt_ids (self ):
189189 return self .shm_prompt_ids .arr [: self .input_len ].tolist ()
190190
191+ def get_prompt_ids_numpy (self ):
192+ return self .shm_prompt_ids .arr [: self .input_len ]
193+
191194 def to_router_rpc_obj (self ):
192195 if hasattr (self , "multimodal_params" ):
193196 return (
Original file line number | Diff line number | Diff line change
@@ -116,15 +116,14 @@ def generate_new_batch(self, current_batch: Batch):
116116 if ok_insert :
117117 can_run_list .extend (cur_group_reqs )
118118
119+ new_batch = None
119120 if len (can_run_list ) != 0 :
120121 new_batch = Batch (uuid .uuid4 ().int , can_run_list , dp_size_in_node = self .dp_size_in_node )
121- for req in abort_req_list :
122- self .router .shm_req_manager .put_back_req_obj (req )
123122
124- self . waiting_req_list = self . waiting_req_list [ len ( can_run_list ) + aborted_count :]
125- return new_batch
126- else :
127- return None
123+ for req in abort_req_list :
124+ self . router . shm_req_manager . put_back_req_obj ( req )
125+ self . waiting_req_list = self . waiting_req_list [ len ( can_run_list ) + aborted_count :]
126+ return new_batch
128127
129128 def _add_to_group (self , cur_group_reqs , req : Req ):
130129 if len (cur_group_reqs ) == 0 :
Original file line number | Diff line number | Diff line change
@@ -91,15 +91,13 @@ def generate_new_batch(self, current_batch: Batch):
9191 can_run_list .append (req )
9292 else :
9393 break
94-
94+ new_batch = None
9595 if len (can_run_list ) != 0 :
9696 new_batch = Batch (uuid .uuid4 ().int , can_run_list , dp_size_in_node = self .dp_size_in_node )
97- for req in abort_req_list :
98- self .router .shm_req_manager .put_back_req_obj (req )
99- self .waiting_req_list = self .waiting_req_list [len (can_run_list ) + aborted_count :]
100- return new_batch
101- else :
102- return None
97+ for req in abort_req_list :
98+ self .router .shm_req_manager .put_back_req_obj (req )
99+ self .waiting_req_list = self .waiting_req_list [len (can_run_list ) + aborted_count :]
100+ return new_batch
103101
104102 def _calcu_batch_token_load_batch_not_none (self , current_batch : Batch ):
105103 is_busy = self .is_busy ()
Original file line number | Diff line number | Diff line change
@@ -48,15 +48,13 @@ def generate_new_batch(self, current_batch: Batch):
4848 can_run_list .append (req )
4949 else :
5050 break
51-
51+ new_batch = None
5252 if len (can_run_list ) != 0 :
5353 new_batch = Batch (uuid .uuid4 ().int , can_run_list , dp_size_in_node = self .dp_size_in_node )
54- for req in abort_req_list :
55- self .router .shm_req_manager .put_back_req_obj (req )
56- self .waiting_req_list = self .waiting_req_list [len (can_run_list ) + aborted_count :]
57- return new_batch
58- else :
59- return None
54+ for req in abort_req_list :
55+ self .router .shm_req_manager .put_back_req_obj (req )
56+ self .waiting_req_list = self .waiting_req_list [len (can_run_list ) + aborted_count :]
57+ return new_batch
6058
6159 def _calcu_batch_token_load_batch_not_none (self , current_batch : Batch ):
6260 is_busy = self .is_busy ()
You can’t perform that action at this time.
0 commit comments