Skip to content

Commit cb4a8cc

Browse files
committed
support qwen3next prefix cache
1 parent 5e2375d commit cb4a8cc

40 files changed

+5465
-9
lines changed

lightllm/common/basemodel/layer_weights/hf_load_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye
6060
transformer_layer_list=transformer_layer_list,
6161
weight_dir=weight_dir,
6262
) # noqa
63-
worker = int(os.environ.get("LOADWORKER", 1))
63+
worker = int(os.environ.get("LOADWORKER", 16))
6464
with Pool(worker) as p:
6565
iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1)
6666
desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers"

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
1010
from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
1111
from .fused_moe_weight_ep import FusedMoeWeightEP
12+
from .parameter_weight import ParameterWeight, TpParameterWeight
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
from typing import Dict
3+
from .base_weight import BaseWeightTpl
4+
from lightllm.utils.dist_utils import get_current_device_id
5+
6+
7+
class ParameterWeight(BaseWeightTpl):
    """A named weight tensor, with an optional bias, picked out of a checkpoint dict.

    ``weight`` and ``bias`` start as ``None`` and are filled in by
    ``load_hf_weights`` when their keys appear in the dict passed to it.
    """

    def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None):
        """Record the checkpoint keys and the dtype to cast loaded tensors to.

        Args:
            weight_name: Key of the weight tensor in the checkpoint dict.
            data_type: dtype every loaded tensor is converted to.
            bias_name: Key of the bias tensor, or ``None`` when there is no bias.
        """
        super().__init__()
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.data_type_ = data_type
        # Nothing is loaded yet; verify_load() reports failure until it is.
        self.weight = None
        self.bias = None

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Copy the weight and bias out of ``weights`` if their keys are present.

        Each tensor found is cast to ``data_type_`` and then moved to the CUDA
        device given by ``get_current_device_id()``. Keys that are absent are
        left untouched, so this can be called once per checkpoint shard.
        """
        targets = (("weight", self.weight_name), ("bias", self.bias_name))
        for attr, key in targets:
            # A bias_name of None is simply never found in the dict.
            if key in weights:
                casted = weights[key].to(self.data_type_)
                setattr(self, attr, casted.cuda(get_current_device_id()))

    def verify_load(self):
        """Return True when the weight is loaded and, if a bias was requested, so is the bias."""
        if self.weight is None:
            return False
        # The bias is only required when a bias key was configured.
        if self.bias_name is not None and self.bias is None:
            return False
        return True
30+
31+
32+
class TpParameterWeight(ParameterWeight):
    """ParameterWeight that keeps only this rank's slice of the first dimension.

    The slice is ``[split_n_embed * tp_rank_, split_n_embed * (tp_rank_ + 1))``.
    NOTE(review): ``tp_rank_`` is not set in this class; presumably it comes
    from ``BaseWeightTpl`` — confirm.
    """

    def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None):
        """Store the per-rank slice length on top of the ParameterWeight setup.

        Args:
            weight_name: Key of the weight tensor in the checkpoint dict.
            data_type: dtype every loaded tensor is converted to.
            split_n_embed: Number of leading-dimension entries kept per rank.
            bias_name: Key of the bias tensor, or ``None`` when there is no bias.
        """
        super().__init__(weight_name, data_type, bias_name)
        self.split_n_embed = split_n_embed

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Load this rank's slice of the weight and bias when their keys are present.

        The slice is taken before the dtype cast and the move to the CUDA
        device given by ``get_current_device_id()``.
        """
        rank = self.tp_rank_
        shard = slice(self.split_n_embed * rank, self.split_n_embed * (rank + 1))

        for attr, key in (("weight", self.weight_name), ("bias", self.bias_name)):
            # A bias_name of None is simply never found in the dict.
            if key in weights:
                piece = weights[key][shard].to(self.data_type_)
                setattr(self, attr, piece.cuda(get_current_device_id()))

lightllm/common/mem_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def free(self, free_index: Union[torch.Tensor, List[int]]):
6969
Args:
7070
free_index (torch.Tensor): _description_
7171
"""
72-
7372
end = self.mark_start
7473
start = self.mark_start - len(free_index)
7574
assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
@@ -121,7 +120,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
121120
self.dtype = dtype
122121
# profile the max total token num if the size is None
123122
self.profile_size(mem_fraction)
124-
super().__init__(self.siz, mem_manager_name)
123+
super().__init__(self.size, mem_manager_name)
125124

126125
self._init_buffers(
127126
self.size,

lightllm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from lightllm.models.qwen2.model import Qwen2TpPartModel
99
from lightllm.models.qwen3.model import Qwen3TpPartModel
1010
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
11+
from lightllm.models.qwen3next.model import Qwen3NextTpPartModel
1112
from lightllm.models.chatglm2.model import ChatGlm2TpPartModel
1213
from lightllm.models.internlm.model import InternlmTpPartModel
1314
from lightllm.models.stablelm.model import StablelmTpPartModel

lightllm/models/qwen2/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, kvargs):
1717

1818
def _init_config(self):
1919
super()._init_config()
20-
if self.config["sliding_window"] is None:
20+
if self.config.get("sliding_window") is None:
2121
self.config["sliding_window"] = self.max_total_token_num
2222
# rename key [SYM: to be confirmed]
2323
return
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import os
2+
import torch
3+
import torch.functional as F
4+
import torch.distributed as dist
5+
import numpy as np
6+
7+
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
8+
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
9+
from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward
10+
11+
class Qwen3NextPostLayerInfer(LlamaPostLayerInfer):
    """Post-layer inference whose ``_norm`` step runs ``gemma_rmsnorm_forward``."""

    def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
        """Normalize ``input`` with the final norm weight and return the result.

        The output buffer comes from ``self.alloc_tensor`` with the same shape
        and dtype as ``input``; ``infer_state`` is accepted but not used here.
        """
        normed = self.alloc_tensor(input.shape, input.dtype)
        gemma_rmsnorm_forward(
            input,
            layer_weight.final_norm_weight_,
            self.eps_,
            out=normed,
        )
        return normed

0 commit comments

Comments
 (0)