diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 1973d670d..3aee08769 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -614,6 +614,7 @@ async def _collect_generation_results(
                 "text": request_output,
                 "logprob": metadata.get("logprob", None),
                 "id": metadata.get("id", None),
+                "top_logprobs": metadata.get("top_logprobs", None),
             }
             token_infos.append(token_info)
 
@@ -693,16 +694,30 @@ def _build_logprobs_data(result: Dict, request: CompletionRequest, tokenizer) ->
     all_tokens = []
     all_token_logprobs = []
     all_text_offsets = []
+    all_top_logprobs = []
     offset = 0
 
     def add_tokens_to_logprobs(token_ids=None, token_infos=None, logprob_map=None):
         nonlocal offset
 
-        def add_single_token(token_text: str, logprob: float):
+        def add_single_token(token_text: str, logprob: float, top_logprobs: List[Dict[int, float]] = None):
             nonlocal offset
             all_tokens.append(token_text)
             all_token_logprobs.append(logprob)
             all_text_offsets.append(offset)
+            if top_logprobs is not None:
+                formatted_top_logprobs = {}
+                for item in top_logprobs:
+                    for t_id, t_prob in item.items():
+                        t_text = tokenizer.decode([t_id], skip_special_tokens=False)
+                        formatted_top_logprobs[t_text] = t_prob
+                all_top_logprobs.append(formatted_top_logprobs)
+            else:
+                if logprob is not None:
+                    all_top_logprobs.append({token_text: logprob})
+                else:
+                    all_top_logprobs.append(None)
+
             offset += len(token_text)
 
         if token_ids is not None:
@@ -712,7 +727,7 @@ def add_single_token(token_text: str, logprob: float):
                 add_single_token(token_text, logprob)
         elif token_infos is not None:
             for token_info in token_infos:
-                add_single_token(token_info["text"], token_info["logprob"])
+                add_single_token(token_info["text"], token_info["logprob"], token_info.get("top_logprobs", None))
 
     # Handle prompt tokens in echo mode
     if request.echo and result.get("prompt_logprobs") is not None:
@@ -743,18 +758,9 @@ def add_single_token(token_text: str, logprob: float):
     if result.get("token_infos"):
         add_tokens_to_logprobs(token_infos=result["token_infos"])
 
-    top_logprobs_list = []
-    for i, (token, logprob) in enumerate(zip(all_tokens, all_token_logprobs)):
-        if logprob is not None:
-            # TODO: a proper implementation needs the top-k logprobs from the backend;
-            # the backend does not support that yet, so only the sampled token's logprob is available
-            top_logprobs_list.append({token: logprob})
-        else:
-            top_logprobs_list.append(None)
-
     return {
         "tokens": all_tokens,
         "token_logprobs": all_token_logprobs,
-        "top_logprobs": top_logprobs_list,
+        "top_logprobs": all_top_logprobs,
        "text_offset": all_text_offsets,
     }
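
Note: the fragment below is an illustration only (not part of the patch) of the data shape _build_logprobs_data is now expected to return once the backend supplies per-token top_logprobs. The token ids, texts, and values are made up; the real token text comes from the server's tokenizer.

    # Hypothetical token_infos as collected by _collect_generation_results:
    token_infos = [
        {"text": "Hello", "logprob": -0.12, "id": 9906,
         "top_logprobs": [{9906: -0.12}, {13225: -2.31}]},
        {"text": "!", "logprob": -0.05, "id": 0, "top_logprobs": None},
    ]
    # Resulting logprobs payload (ids decoded to text, OpenAI completions style):
    # {
    #     "tokens": ["Hello", "!"],
    #     "token_logprobs": [-0.12, -0.05],
    #     "top_logprobs": [{"Hello": -0.12, "Hi": -2.31}, {"!": -0.05}],
    #     "text_offset": [0, 5],
    # }
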
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 947f24644..a8cdd91fc 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -16,6 +16,8 @@
 
 logger = init_logger(__name__)
 
+MAX_TOP_K_LOGPROBS = 20
+
 
 class FinishStatus(ctypes.Structure):
     _pack_ = 4
@@ -170,6 +172,7 @@ def init(
         self.input_len = len(prompt_ids)
         self.alloc_shm_numpy_len = self.input_len + self.sample_params.max_new_tokens + 1024  # + 1024 for safe
         self.create_logprobs_shm_array()
+        self.create_top_logprobs_shm_array()
         self.create_prompt_ids_shm_array()
         self.chunked_prefill_size = chunked_prefill_size
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
@@ -218,6 +221,17 @@ def create_logprobs_shm_array(self):
         self.shm_logprobs.create_shm()
         return
 
+    def create_top_logprobs_shm_array(self):
+        service_uni_name = get_unique_server_name()
+        name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
+        self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
+        self.shm_top_logprobs_ids.create_shm()
+
+        name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
+        self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
+        self.shm_top_logprobs_val.create_shm()
+        return
+
     def link_logprobs_shm_array(self):
         service_uni_name = get_unique_server_name()
         name = f"{service_uni_name}_shm_logprobs_{self.index_in_shm_mem}"
@@ -225,6 +239,17 @@ def link_logprobs_shm_array(self):
         self.shm_logprobs.link_shm()
         return
 
+    def link_top_logprobs_shm_array(self):
+        service_uni_name = get_unique_server_name()
+        name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
+        self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
+        self.shm_top_logprobs_ids.link_shm()
+
+        name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
+        self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
+        self.shm_top_logprobs_val.link_shm()
+        return
+
     def get_prompt_ids(self):
         return self.shm_prompt_ids.arr[: self.input_len].tolist()
 
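
Note: a back-of-the-envelope sketch (not part of the patch) of the extra shared memory the two new arrays cost per request slot; alloc_shm_numpy_len varies per request, and the 4096 below is only an example value.

    MAX_TOP_K_LOGPROBS = 20
    alloc_shm_numpy_len = 4096  # example: input_len + max_new_tokens + 1024
    ids_bytes = alloc_shm_numpy_len * MAX_TOP_K_LOGPROBS * 4   # int32 ids array
    vals_bytes = alloc_shm_numpy_len * MAX_TOP_K_LOGPROBS * 4  # float32 values array
    print((ids_bytes + vals_bytes) / 1024)  # ~640 KiB of additional shm per request slot
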
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 11919398e..e66de2f8a 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -25,6 +25,7 @@
 from lightllm.server.core.objs import Req, FinishStatus, StartArgs
 from lightllm.server.core.objs import SamplingParams
 from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
+from lightllm.server.core.objs.req import MAX_TOP_K_LOGPROBS
 from lightllm.server.core.objs.io_objs import GroupReqObjs
 from lightllm.server.core.objs.shm_req_manager import ShmReqManager
 from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
@@ -730,6 +731,17 @@ async def handle_loop(self):
                         "cpu_prompt_cache_len": req.cpu_prompt_cache_len,
                         "mtp_accepted_token_num": req.mtp_accepted_token_num,
                     }
+
+                    top_k_ids = req.shm_top_logprobs_ids.arr[src_index]
+                    top_k_vals = req.shm_top_logprobs_val.arr[src_index]
+                    top_logprobs = []
+                    for i in range(MAX_TOP_K_LOGPROBS):
+                        if top_k_vals[i] == -float("inf"):
+                            break
+                        top_logprobs.append({int(top_k_ids[i]): float(top_k_vals[i])})
+                    if top_logprobs:
+                        metadata["top_logprobs"] = top_logprobs
+
                     if self.args.return_all_prompt_logprobs:
                         metadata.update(req.get_all_prompt_metadata())
                     if self.args.use_reward_model:
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 3fe3f5136..804e82ac5 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -10,6 +10,7 @@
 from lightllm.common.req_manager import ReqManager
 from lightllm.utils.infer_utils import mark_start, mark_end
 from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
+from lightllm.server.core.objs.req import MAX_TOP_K_LOGPROBS
 from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.req_id_generator import convert_sub_id_to_group_id
@@ -361,6 +362,7 @@ def _init_all_state(self):
         self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
         self.shm_req.link_prompt_ids_shm_array()
         self.shm_req.link_logprobs_shm_array()
+        self.shm_req.link_top_logprobs_shm_array()
         self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
 
         # In nixl pd-disaggregation mode, update the position from which the prefill node starts transferring
@@ -453,10 +455,26 @@ def get_chuncked_input_token_len(self):
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
         return chunked_end
 
-    def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int):
+    def set_next_gen_token_id(
+        self,
+        next_token_id: int,
+        logprob: float,
+        output_len: int,
+        top_k_ids: List[int] = None,
+        top_k_logprobs: List[float] = None,
+    ):
         index = self.shm_req.input_len + output_len
         self.shm_req.shm_prompt_ids.arr[index - 1] = next_token_id
         self.shm_req.shm_logprobs.arr[index - 1] = logprob
+
+        if top_k_ids is not None and top_k_logprobs is not None:
+            k = min(len(top_k_ids), MAX_TOP_K_LOGPROBS)
+            self.shm_req.shm_top_logprobs_ids.arr[index - 1, :k] = top_k_ids[:k]
+            self.shm_req.shm_top_logprobs_val.arr[index - 1, :k] = top_k_logprobs[:k]
+            # Zero out the rest if any
+            if k < MAX_TOP_K_LOGPROBS:
+                self.shm_req.shm_top_logprobs_ids.arr[index - 1, k:] = 0
+                self.shm_req.shm_top_logprobs_val.arr[index - 1, k:] = -float("inf")
         return
 
     def update_mtp_accepted_token_num(self, accept_token_num: int):
@@ -528,6 +546,8 @@ def handle(
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]],
         is_master_in_dp: bool,
         nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None,
+        top_k_ids: List[int] = None,
+        top_k_logprobs: List[float] = None,
     ):
        # nixl_prefill_chuncked_handle_func mainly handles, in nixl prefill mode,
        # the pd chunked transfer that follows chunked prefill.
@@ -540,7 +560,7 @@ def handle(
         req_obj = self.req_obj
         shm_req = req_obj.shm_req
         finish_status = req_obj.finish_status
-        req_obj.set_next_gen_token_id(next_token_id, next_token_logprob, self.output_len)
+        req_obj.set_next_gen_token_id(next_token_id, next_token_logprob, self.output_len, top_k_ids, top_k_logprobs)
 
         # The point of deciding this early is:
         # in mtp mode, the same req object can be handled multiple times,
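
Note: an illustrative numpy sketch (not part of the patch) of the -inf sentinel contract between the writer (set_next_gen_token_id above) and the reader in handle_loop: unused slots are padded with -inf, and the reader stops at the first sentinel. The values and the toy k = 3 are made up.

    import numpy as np

    MAX_TOP_K_LOGPROBS = 20
    ids_row = np.zeros(MAX_TOP_K_LOGPROBS, dtype=np.int32)
    val_row = np.full(MAX_TOP_K_LOGPROBS, -float("inf"), dtype=np.float32)

    # writer side: only k = 3 candidates for this position
    ids_row[:3] = [11, 42, 7]
    val_row[:3] = [-0.5, -1.75, -3.25]

    # reader side: rebuild the list-of-dicts stored in metadata["top_logprobs"]
    top_logprobs = []
    for i in range(MAX_TOP_K_LOGPROBS):
        if val_row[i] == -float("inf"):
            break
        top_logprobs.append({int(ids_row[i]): float(val_row[i])})
    # -> [{11: -0.5}, {42: -1.75}, {7: -3.25}]
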
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 95f0c9951..501cc7317 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -288,7 +288,13 @@ def init_mtp_draft_model(self, main_kvargs: dict):
                 self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
         return
 
-    def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
+    def _async_copy_next_token_infos_to_pin_mem(
+        self,
+        next_token_ids: torch.Tensor,
+        next_token_logprobs: torch.Tensor,
+        top_k_ids: torch.Tensor = None,
+        top_k_logprobs: torch.Tensor = None,
+    ):
         """
         This function copies the next token ids and logprobs into pinned memory,
         which guarantees that the post_handle function can read valid output results.
@@ -301,7 +307,20 @@ def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor,
             key="next_token_logprobs",
             gpu_tensor=next_token_logprobs,
         )
-        return next_token_ids_cpu, next_token_logprobs_cpu
+
+        top_k_ids_cpu = None
+        top_k_logprobs_cpu = None
+        if top_k_ids is not None:
+            top_k_ids_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
+                key="top_k_ids",
+                gpu_tensor=top_k_ids,
+            )
+            top_k_logprobs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
+                key="top_k_logprobs",
+                gpu_tensor=top_k_logprobs,
+            )
+
+        return next_token_ids_cpu, next_token_logprobs_cpu, top_k_ids_cpu, top_k_logprobs_cpu
 
     def _try_read_new_reqs(self):
         if self.is_multinode_tp:
@@ -646,19 +665,27 @@ def _post_handle(
         run_reqs_update_packs: List[InferReqUpdatePack],
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
         nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None,
+        top_k_ids: List[List[int]] = None,
+        top_k_logprobs: List[List[float]] = None,
     ):
         """
        extra_post_req_handle_func provides extra post-processing once a request's output token is determined;
        it is mainly used by constrained-output modes to update their internal state machines and add extra stop conditions.
         """
-        for req_obj, next_token_id, next_token_logprob, pack in zip(
-            run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs
+        if top_k_ids is None:
+            top_k_ids = [None] * len(run_reqs)
+            top_k_logprobs = [None] * len(run_reqs)
+
+        for req_obj, next_token_id, next_token_logprob, cur_top_k_ids, cur_top_k_logprobs, pack in zip(
+            run_reqs, next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs, run_reqs_update_packs
         ):
             req_obj: InferReq = req_obj
             pack: InferReqUpdatePack = pack
             pack.handle(
                 next_token_id=next_token_id,
                 next_token_logprob=next_token_logprob,
+                top_k_ids=cur_top_k_ids,
+                top_k_logprobs=cur_top_k_logprobs,
                 eos_ids=self.eos_id,
                 extra_post_req_handle_func=extra_post_req_handle_func,
                 is_master_in_dp=self.is_master_in_dp,
@@ -724,7 +751,7 @@ def _sample_and_scatter_token(
             assert len(run_reqs) == logits.shape[0]
             mask_func(run_reqs, logits)
 
-        next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
+        next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs = sample(logits, run_reqs, self.eos_id)
         b_has_out = None
         if is_prefill:
             b_has_out = g_pin_mem_manager.gen_from_list(
@@ -743,10 +770,13 @@ def _sample_and_scatter_token(
                 next_token_ids=next_token_ids,
                 mask=b_has_out,
             )
-        next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
-            next_token_ids, next_token_logprobs
-        )
-        return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu
+        (
+            next_token_ids_cpu,
+            next_token_logprobs_cpu,
+            top_k_ids_cpu,
+            top_k_logprobs_cpu,
+        ) = self._async_copy_next_token_infos_to_pin_mem(next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs)
+        return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu, top_k_ids_cpu, top_k_logprobs_cpu
 
     def _dp_all_gather_prefill_and_decode_req_num(
         self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]
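
Note: assumed shapes flowing out of _sample_and_scatter_token after this change (an informal sketch, not part of the patch); B is the number of running requests and K the top-k width (20 here).

    # next_token_ids          : GPU int tensor, shape (B,), the sampled token per request
    # next_token_ids_cpu      : pinned-memory copy, shape (B,)
    # next_token_logprobs_cpu : pinned-memory copy, shape (B,)
    # top_k_ids_cpu           : pinned-memory copy, shape (B, K), or None if top-k was not computed
    # top_k_logprobs_cpu      : pinned-memory copy, shape (B, K), or None
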
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index 8d47d1057..5aab024e9 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -122,7 +122,13 @@ def prefill_normal(
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output = self.model.forward(model_input)
             if run_reqs_num > 0:
-                _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                (
+                    _,
+                    next_token_ids_cpu,
+                    next_token_logprobs_cpu,
+                    top_k_ids_cpu,
+                    top_k_logprobs_cpu,
+                ) = self._sample_and_scatter_token(
                     logits=model_output.logits[:run_reqs_num],
                     b_req_idx=model_input.b_req_idx[:run_reqs_num],
                     b_mtp_index=model_input.b_mtp_index[:run_reqs_num],
@@ -149,6 +155,8 @@ def prefill_normal(
                 run_reqs_update_packs=update_packs,
                 extra_post_req_handle_func=self.extra_post_req_handle_func,
                 nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func,
+                top_k_ids=top_k_ids_cpu,
+                top_k_logprobs=top_k_logprobs_cpu,
             )
             # Stage 4
             event_pack.notify_pre_post_handle()
@@ -165,7 +173,13 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output = self.model.forward(model_input)
             if run_reqs_num > 0:
-                _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                (
+                    _,
+                    next_token_ids_cpu,
+                    next_token_logprobs_cpu,
+                    top_k_ids_cpu,
+                    top_k_logprobs_cpu,
+                ) = self._sample_and_scatter_token(
                     logits=model_output.logits[:run_reqs_num],
                     b_req_idx=model_input.b_req_idx[:run_reqs_num],
                     b_mtp_index=model_input.b_mtp_index[:run_reqs_num],
@@ -190,6 +204,8 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
                 next_token_logprobs=next_token_logprobs_cpu,
                 run_reqs_update_packs=update_packs,
                 extra_post_req_handle_func=self.extra_post_req_handle_func,
+                top_k_ids=top_k_ids_cpu,
+                top_k_logprobs=top_k_logprobs_cpu,
             )
 
             # Stage 4
@@ -230,7 +246,13 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
 
 
             if (req_num0 + req_num1) > 0:
-                _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                (
+                    _,
+                    next_token_ids_cpu,
+                    next_token_logprobs_cpu,
+                    top_k_ids_cpu,
+                    top_k_logprobs_cpu,
+                ) = self._sample_and_scatter_token(
                     logits=logits,
                     b_req_idx=b_req_idx,
                     b_mtp_index=b_mtp_index,
@@ -258,6 +280,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
                 run_reqs_update_packs=update_packs,
                 extra_post_req_handle_func=self.extra_post_req_handle_func,
                 nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func,
+                top_k_ids=top_k_ids_cpu,
+                top_k_logprobs=top_k_logprobs_cpu,
             )
             # Stage 4
             event_pack.notify_pre_post_handle()
@@ -294,7 +318,13 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
             run_reqs = run_reqs0 + run_reqs1
 
             if (req_num0 + req_num1) > 0:
-                _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                (
+                    _,
+                    next_token_ids_cpu,
+                    next_token_logprobs_cpu,
+                    top_k_ids_cpu,
+                    top_k_logprobs_cpu,
+                ) = self._sample_and_scatter_token(
                     logits=logits,
                     b_req_idx=b_req_idx,
                     b_mtp_index=b_mtp_index,
@@ -319,6 +349,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
                 next_token_logprobs=next_token_logprobs_cpu,
                 run_reqs_update_packs=update_packs,
                 extra_post_req_handle_func=self.extra_post_req_handle_func,
+                top_k_ids=top_k_ids_cpu,
+                top_k_logprobs=top_k_logprobs_cpu,
             )
 
             # Stage 4
@@ -341,7 +373,13 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
             b_mtp_index = model_input.b_mtp_index[0:req_num]
 
             if req_num > 0:
-                next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                (
+                    next_token_ids,
+                    next_token_ids_cpu,
+                    next_token_logprobs_cpu,
+                    top_k_ids_cpu,
+                    top_k_logprobs_cpu,
+                ) = self._sample_and_scatter_token(
                     logits=logits,
                     b_req_idx=b_req_idx,
                     b_mtp_index=b_mtp_index,
@@ -380,6 +418,8 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
                 run_reqs_update_packs=update_packs,
                 extra_post_req_handle_func=self.extra_post_req_handle_func,
                 nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func,
+                top_k_ids=top_k_ids_cpu,
+                top_k_logprobs=top_k_logprobs_cpu,
             )
 
             # Stage 4
@@ -403,9 +443,14 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
             b_mtp_index_cpu = b_mtp_index_cpu[0:req_num]
             b_req_idx = model_input.b_req_idx[0:req_num]
 
-            next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
-                next_token_ids, next_token_logprobs
+            next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs = sample(logits, run_reqs, self.eos_id)
+            (
+                next_token_ids_cpu,
+                next_token_logprobs_cpu,
+                top_k_ids_cpu,
+                top_k_logprobs_cpu,
+            ) = self._async_copy_next_token_infos_to_pin_mem(
+                next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs
             )
 
             # verify the next_token_ids
@@ -620,7 +665,13 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
             b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0)
 
             if (req_num0 + req_num1) > 0:
-                next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                (
+                    next_token_ids,
+                    next_token_ids_cpu,
+                    next_token_logprobs_cpu,
+                    top_k_ids_cpu,
+                    top_k_logprobs_cpu,
+                ) = self._sample_and_scatter_token(
                     logits=logits,
                     run_reqs=run_reqs,
                     b_req_idx=b_req_idx,
@@ -680,6 +731,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
                 run_reqs_update_packs=update_packs,
                 extra_post_req_handle_func=self.extra_post_req_handle_func,
                 nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func,
+                top_k_ids=top_k_ids_cpu,
+                top_k_logprobs=top_k_logprobs_cpu,
             )
             event_pack.notify_pre_post_handle()
         else:
@@ -714,9 +767,14 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
             )
             logits[0:req_num0, :].copy_(logits0[0:req_num0, :], non_blocking=True)
             logits[req_num0 : (req_num0 + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)
-            next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
-                next_token_ids, next_token_logprobs
+            next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs = sample(logits, run_reqs, self.eos_id)
+            (
+                next_token_ids_cpu,
+                next_token_logprobs_cpu,
+                top_k_ids_cpu,
+                top_k_logprobs_cpu,
+            ) = self._async_copy_next_token_infos_to_pin_mem(
+                next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs
             )
 
             b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0)
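
Note: the MTP decode paths above unpack four values from sample(); the file below is where the extra pair comes from. An assumed contract, for orientation only (B = batch size, K = 20):

    # next_token_ids      : (B,)   sampled token ids
    # next_token_logprobs : (B,)   logprob of each sampled token
    # top_k_ids           : (B, K) ids of the K most probable tokens per request
    # top_k_logprobs      : (B, K) their logprobs, in descending probability order
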
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
index e2ccf290e..77d7439fa 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
@@ -72,7 +72,8 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
         sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
         next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
         next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index))
-        return next_token_ids.view(-1), next_token_logprobs.view(-1)
+        top_k_logprobs_val, top_k_logprobs_idx = _get_top_logprobs(probs, k=20)
+        return next_token_ids.view(-1), next_token_logprobs.view(-1), top_k_logprobs_idx, top_k_logprobs_val
 
     elif get_env_start_args().sampling_backend == "sglang_kernel":
         from sgl_kernel import top_k_top_p_sampling_from_probs
@@ -103,6 +104,12 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     return probs_sort, probs_idx
 
 
+def _get_top_logprobs(probs: torch.Tensor, k: int = 20):
+    top_k_logprobs_val, top_k_logprobs_idx = torch.topk(probs, k=k, dim=-1)
+    top_k_logprobs_val = torch.log(top_k_logprobs_val)
+    return top_k_logprobs_val, top_k_logprobs_idx
+
+
 def _get_post_sample_tensors(reqs: List[InferReq]):
     req_idxes: List[int] = []
     temperatures: List[float] = []
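
Note: a quick self-check (not part of the patch) for _get_top_logprobs: because log is monotonic, top-k of the softmax probabilities followed by log matches top-k of log_softmax taken directly, which is handy when sanity-checking the returned values. Sizes below are toy values.

    import torch

    def _get_top_logprobs(probs: torch.Tensor, k: int = 20):
        top_k_logprobs_val, top_k_logprobs_idx = torch.topk(probs, k=k, dim=-1)
        top_k_logprobs_val = torch.log(top_k_logprobs_val)
        return top_k_logprobs_val, top_k_logprobs_idx

    logits = torch.randn(4, 32000)  # toy (B, V) logits
    probs = torch.softmax(logits, dim=-1)
    val, idx = _get_top_logprobs(probs, k=20)
    ref_val, ref_idx = torch.log_softmax(logits, dim=-1).topk(20, dim=-1)
    assert torch.equal(idx, ref_idx)
    assert torch.allclose(val, ref_val, atol=1e-5)
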