Commit 79b6fe3

Commit message: fix
Parent: 415baea
6 files changed: +67 -105 lines

lightllm/server/api_models.py

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ class CompletionRequest(BaseModel):
     # prompt: string or tokens
     prompt: Union[str, List[str], List[int], List[List[int]]]
     suffix: Optional[str] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = 16000
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1

@@ -89,7 +89,7 @@ class ChatCompletionRequest(BaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     stop: Optional[Union[str, List[str]]] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = 16000
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
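
The practical effect: an OpenAI-compatible completion or chat request that omits max_tokens now defaults to generating up to 16000 tokens instead of 16. A minimal sketch of checking the new default, assuming `prompt` and a required `model` field are the only fields that must be supplied (any field not shown in the diff above is an assumption):

# Minimal sketch, not part of the commit; "demo-model" and the presence of a
# required `model` field are assumptions about the request schema.
from lightllm.server.api_models import CompletionRequest

req = CompletionRequest(model="demo-model", prompt="Hello")
print(req.max_tokens)  # expected: 16000 after this commit (previously 16)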

lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from .chunked_prefill.impl import ChunkedPrefillBackend
 from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
 from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
-from .chunked_prefill.impl_for_qwen3next import Qwen3NextBackend
+from .chunked_prefill.impl_for_hybrid_radix_cache import HybridRadixCacheBackend
 from .chunked_prefill.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend
 from .chunked_prefill.impl_for_reward_model import RewardModelBackend
 from .chunked_prefill.impl_for_token_healing import TokenHealingBackend

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 2 additions & 2 deletions
@@ -164,8 +164,8 @@ def init_model(self, kvargs):
         self.model, self.is_multimodal = get_model(model_cfg, model_kvargs)
         self.model: TpPartBaseModel = self.model  # for easy typing
         set_random_seed(2147483647)
-        is_qwen3next = model_cfg.get("model_type", "") == "qwen3_next"
-        radix_cache_class = RadixCache if not is_qwen3next else HybridRadixCache
+        is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"]
+        radix_cache_class = RadixCache if not is_hybrid_model else HybridRadixCache
         self.radix_cache = (
             radix_cache_class(
                 get_unique_server_name(),
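
Switching from an equality check to list membership leaves room for additional hybrid-attention model types later. A minimal sketch of the pattern, assuming a hypothetical HYBRID_MODEL_TYPES constant (the RadixCache import path is also assumed; the HybridRadixCache path comes from the new backend file below):

# Illustrative only: HYBRID_MODEL_TYPES is a hypothetical name and the
# RadixCache import path is assumed; the selection logic mirrors the diff above.
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache

HYBRID_MODEL_TYPES = ["qwen3_next"]

def select_radix_cache_class(model_cfg: dict):
    is_hybrid_model = model_cfg.get("model_type", "") in HYBRID_MODEL_TYPES
    return HybridRadixCache if is_hybrid_model else RadixCache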
lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+import torch
+from .impl import ChunkedPrefillBackend
+from typing import List
+from typing_extensions import override
+from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
+from lightllm.common.basemodel.infer_lock import g_infer_state_lock
+from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack
+from lightllm.server.router.model_infer.mode_backend.pre import (
+    prepare_prefill_inputs,
+)
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class HybridRadixCacheBackend(ChunkedPrefillBackend):
+    def __init__(self) -> None:
+        super().__init__()
+        logger.info("Using HybridRadixCacheBackend for hybrid attention model.")
+        self.extra_post_req_handle_func = self._handle_hybrid_radix_cache_insert
+
+    @override
+    def init_model(self, kvargs):
+        from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache
+        super().init_model(kvargs)
+        assert isinstance(self.radix_cache, HybridRadixCache)
+        return
+
+    def _handle_hybrid_radix_cache_insert(self, req_obj: "InferReq", next_token_id, next_token_logprob):
+        # TODO : add docs
+        if (req_obj.is_multi_chat_req or
+                req_obj.cur_kv_len >= req_obj.get_cur_total_len()):
+            return
+
+        g_infer_state_lock.acquire()
+        input_token_ids = req_obj.get_input_token_ids()
+        key = torch.tensor(input_token_ids[0 : req_obj.cur_kv_len], dtype=torch.int64, device="cpu")
+
+        value = self.model.req_manager.req_to_token_indexs[req_obj.req_idx][: req_obj.cur_kv_len].cpu()
+
+        buffer_idx = self.model.req_manager.req_to_buffer_indexes[req_obj.req_idx].cpu()
+
+        self.radix_cache.free_radix_cache_to_get_enough_token(0, 1)
+
+        new_buffer_idx = self.model.req_manager.mem_manager.alloc_state_cache_buffer(1)[0]
+        self.model.req_manager.mem_manager.copy_state_cache_buffer(buffer_idx, new_buffer_idx)
+        self.model.req_manager.req_to_buffer_indexes[req_obj.req_idx] = new_buffer_idx
+
+        _, new_shared_kv_node = self.radix_cache.insert(key, value, buffer_idx)
+
+        self.radix_cache.dec_node_ref_counter(req_obj.shared_kv_node)
+        self.radix_cache.add_node_ref_counter(new_shared_kv_node)
+        req_obj.shared_kv_node = new_shared_kv_node
+        g_infer_state_lock.release()
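
For readability, here is a hedged, annotated restatement of the per-request hand-off performed by `_handle_hybrid_radix_cache_insert` (the helper name and flattened parameters below are illustrative, not part of the commit): the request appears to keep decoding on a freshly allocated copy of its recurrent-state buffer, while the original buffer, the matched token-id prefix, and its KV indexes are donated to the hybrid radix cache, and the request's shared-node reference is moved to the newly inserted node.

# Illustrative sketch only; the names mirror the diff above, the helper itself is hypothetical.
def donate_prefix_to_radix_cache(req_obj, req_manager, radix_cache):
    mem_manager = req_manager.mem_manager

    # Key/value for the radix cache: the token-id prefix and its KV cache indexes.
    key = req_obj.get_input_token_ids()[: req_obj.cur_kv_len]
    value = req_manager.req_to_token_indexs[req_obj.req_idx][: req_obj.cur_kv_len]

    # The request's current recurrent-state buffer is handed to the cache,
    # so give the request a fresh copy to keep decoding with.
    old_buffer = req_manager.req_to_buffer_indexes[req_obj.req_idx]
    radix_cache.free_radix_cache_to_get_enough_token(0, 1)
    new_buffer = mem_manager.alloc_state_cache_buffer(1)[0]
    mem_manager.copy_state_cache_buffer(old_buffer, new_buffer)
    req_manager.req_to_buffer_indexes[req_obj.req_idx] = new_buffer

    # Insert the prefix and move the request's shared-node reference counting
    # from the old node to the newly created one.
    _, new_node = radix_cache.insert(key, value, old_buffer)
    radix_cache.dec_node_ref_counter(req_obj.shared_kv_node)
    radix_cache.add_node_ref_counter(new_node)
    req_obj.shared_kv_node = new_node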

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_qwen3next.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

lightllm/server/router/model_infer/model_rpc.py

Lines changed: 8 additions & 4 deletions
@@ -7,11 +7,12 @@
 import setproctitle
 from datetime import timedelta
 from typing import Dict, List, Tuple
+from transformers import PretrainedConfig
 from lightllm.server.router.model_infer.mode_backend import (
     ChunkedPrefillBackend,
     FirstTokenConstraintBackend,
     OutlinesConstraintBackend,
-    Qwen3NextBackend,
+    HybridRadixCacheBackend,
     ReturnPromptLogProbBackend,
     RewardModelBackend,
     TokenHealingBackend,

@@ -121,14 +122,17 @@ def init_model(self, kvargs):
         is_outlines_constraint_mode = self.args.output_constraint_mode == "outlines"
         is_xgrammar_constraint_mode = self.args.output_constraint_mode == "xgrammar"
         assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true"
-        is_qwen3next = True
         is_prefill_node = self.args.run_mode == "prefill"
         is_decode_node = self.args.run_mode == "decode"
         is_nixl_prefill_node = self.args.run_mode == "nixl_prefill"
         is_nixl_decode_node = self.args.run_mode == "nixl_decode"

-        if is_qwen3next:
-            self.backend = Qwen3NextBackend()
+        model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"])
+        is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"]
+        use_hybrid_radix_cache = is_hybrid_model and not self.args.disable_dynamic_prompt_cache
+
+        if use_hybrid_radix_cache:
+            self.backend = HybridRadixCacheBackend()
         elif is_prefill_node:
             if self.args.dp > 1:
                 self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue)
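
A minimal sketch of the new gating, assuming it reflects the intent of the diff: the hard-coded `is_qwen3next = True` is replaced by reading `model_type` from the model config, and the hybrid backend is selected only when the model is a hybrid-attention type and the dynamic prompt cache is enabled. The helper below is hypothetical; only `PretrainedConfig.get_config_dict` and the names in its body come from the diff.

# Hypothetical helper for illustration; it mirrors the selection logic above.
from transformers import PretrainedConfig

def should_use_hybrid_radix_cache(weight_dir: str, disable_dynamic_prompt_cache: bool) -> bool:
    model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
    is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"]
    return is_hybrid_model and not disable_dynamic_prompt_cache

# e.g. should_use_hybrid_radix_cache("/path/to/qwen3_next_model", disable_dynamic_prompt_cache=False)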
