
Commit a4bc0d6

fix abort (#992)
1 parent: 1c16247

File tree: 5 files changed, +19 −21 lines

  lightllm/common/req_manager.py
  lightllm/server/core/objs/req.py
  lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
  lightllm/server/router/req_queue/chunked_prefill/impl.py
  lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py


lightllm/common/req_manager.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def init_req_sampling_params(self, req):
         else:
             self.req_to_out_token_id_counter[req.req_idx].fill_(0)
         if req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics:
-            prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids()).pin_memory().cuda(non_blocking=True)
+            prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True)
             token_id_counter(
                 prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx]
             )
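Why this one-line change matters: torch.from_numpy accepts only a NumPy ndarray, while get_prompt_ids() returns a plain Python list (see the req.py diff below), so the old call would raise a TypeError at runtime. A minimal sketch of the type mismatch the new get_prompt_ids_numpy accessor avoids:

    import numpy as np
    import torch

    arr = np.array([1, 2, 3], dtype=np.int64)
    torch.from_numpy(arr)           # OK: zero-copy tensor sharing the ndarray's buffer
    torch.from_numpy(arr.tolist())  # TypeError: expected np.ndarray (got list)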

lightllm/server/core/objs/req.py

Lines changed: 3 additions & 0 deletions
@@ -188,6 +188,9 @@ def link_logprobs_shm_array(self):
     def get_prompt_ids(self):
         return self.shm_prompt_ids.arr[: self.input_len].tolist()

+    def get_prompt_ids_numpy(self):
+        return self.shm_prompt_ids.arr[: self.input_len]
+
     def to_router_rpc_obj(self):
         if hasattr(self, "multimodal_params"):
             return (
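Note that get_prompt_ids_numpy returns the slice as-is. With NumPy basic slicing this is a view rather than a copy, so the caller in req_manager.py stays zero-copy until it explicitly pins host memory. A sketch of that call chain, assuming shm_prompt_ids.arr is a NumPy array backed by shared memory:

    # Sketch of the consuming call chain from req_manager.py; assumes
    # shm_prompt_ids.arr is a NumPy array backed by shared memory.
    ids = req.shm_req.get_prompt_ids_numpy()  # zero-copy view into the shm buffer
    t = torch.from_numpy(ids)                 # still zero-copy: tensor aliases the view
    t = t.pin_memory()                        # first real copy, into page-locked host memory
    prompt_ids = t.cuda(non_blocking=True)    # async host-to-device transfer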

lightllm/server/router/req_queue/chunked_prefill/beam_impl.py

Lines changed: 5 additions & 6 deletions
@@ -116,15 +116,14 @@ def generate_new_batch(self, current_batch: Batch):
             if ok_insert:
                 can_run_list.extend(cur_group_reqs)

+        new_batch = None
         if len(can_run_list) != 0:
             new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
-            for req in abort_req_list:
-                self.router.shm_req_manager.put_back_req_obj(req)

-            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
-            return new_batch
-        else:
-            return None
+        for req in abort_req_list:
+            self.router.shm_req_manager.put_back_req_obj(req)
+        self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+        return new_batch

     def _add_to_group(self, cur_group_reqs, req: Req):
         if len(cur_group_reqs) == 0:
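This diff (repeated in the two scheduler files below) carries the actual abort fix. Previously the put-back and queue-trim cleanup lived inside the `if len(can_run_list) != 0:` branch: when every waiting request had been aborted, can_run_list stayed empty, the function returned None early, the aborted request objects were never handed back to shm_req_manager, and waiting_req_list kept its stale entries. The rewrite defaults new_batch to None and runs the cleanup unconditionally. A self-contained sketch of the difference, using simplified stand-ins rather than LightLLM's real classes:

    # Self-contained sketch of the early-return leak this commit fixes;
    # Pool and the flow functions are stand-ins, not LightLLM's real classes.
    class Pool:
        def __init__(self):
            self.free = 0

        def put_back_req_obj(self, req):
            self.free += 1

    def old_flow(can_run_list, abort_req_list, waiting, pool):
        aborted_count = len(abort_req_list)
        if len(can_run_list) != 0:
            for req in abort_req_list:
                pool.put_back_req_obj(req)
            return waiting[len(can_run_list) + aborted_count :], "new_batch"
        else:
            return waiting, None  # aborted reqs never returned to the pool

    def new_flow(can_run_list, abort_req_list, waiting, pool):
        aborted_count = len(abort_req_list)
        new_batch = "new_batch" if len(can_run_list) != 0 else None
        for req in abort_req_list:  # now runs even when no batch is formed
            pool.put_back_req_obj(req)
        return waiting[len(can_run_list) + aborted_count :], new_batch

    pool = Pool()
    waiting, _ = old_flow([], ["aborted_req"], ["aborted_req"], pool)
    print(pool.free, waiting)  # 0 ['aborted_req'] -> leaked object, stale queue entry

    pool = Pool()
    waiting, _ = new_flow([], ["aborted_req"], ["aborted_req"], pool)
    print(pool.free, waiting)  # 1 [] -> object reclaimed, queue trimmed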

lightllm/server/router/req_queue/chunked_prefill/impl.py

Lines changed: 5 additions & 7 deletions
@@ -91,15 +91,13 @@ def generate_new_batch(self, current_batch: Batch):
                 can_run_list.append(req)
             else:
                 break
-
+        new_batch = None
         if len(can_run_list) != 0:
             new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
-            for req in abort_req_list:
-                self.router.shm_req_manager.put_back_req_obj(req)
-            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
-            return new_batch
-        else:
-            return None
+        for req in abort_req_list:
+            self.router.shm_req_manager.put_back_req_obj(req)
+        self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+        return new_batch

     def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
         is_busy = self.is_busy()
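Same restructuring as in beam_impl.py above: new_batch defaults to None and the abort cleanup is hoisted out of the if-branch, so it also runs when no new batch can be formed.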

lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py

Lines changed: 5 additions & 7 deletions
@@ -48,15 +48,13 @@ def generate_new_batch(self, current_batch: Batch):
                 can_run_list.append(req)
             else:
                 break
-
+        new_batch = None
         if len(can_run_list) != 0:
             new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node)
-            for req in abort_req_list:
-                self.router.shm_req_manager.put_back_req_obj(req)
-            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
-            return new_batch
-        else:
-            return None
+        for req in abort_req_list:
+            self.router.shm_req_manager.put_back_req_obj(req)
+        self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+        return new_batch

     def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
         is_busy = self.is_busy()
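The PD-decode queue receives the identical restructuring, keeping abort handling consistent across all three chunked-prefill scheduler variants.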
