Merged

74 commits
defa581
[add]add whisper sdpa
SangChengC Sep 26, 2025
87c15dc
[add]add qwen3-vl-moe support
Oct 22, 2025
c318d72
fix1103
Nov 3, 2025
2ebdb58
add qwen3-vl support
Nov 21, 2025
1588ff3
1203
Dec 3, 2025
cd9c7ee
Merge branch 'main' into add-qwen3-vl
Dec 4, 2025
3ee963e
1204
Dec 4, 2025
0da89eb
1210
Dec 10, 2025
02486eb
1210
Dec 10, 2025
f6c5d64
Merge branch 'main' into add-qwen3-vl
Dec 10, 2025
1902799
1210
Dec 10, 2025
ebd5f7c
1210
Dec 10, 2025
ae29f70
1210
Dec 10, 2025
544f625
1210
Dec 10, 2025
46d2414
fix-qwen2-vl-mrope-pos-id
Dec 12, 2025
79c1fcf
mrope refactor (chunkedprefill waiting to verify)
shihaobai Dec 14, 2025
95fc1d4
fix chunked prefill
shihaobai Dec 14, 2025
29fd280
improve mrope
shihaobai Dec 15, 2025
4412556
add vlm acc benchmark
shihaobai Dec 15, 2025
c08eb0c
remove comment
shihaobai Dec 15, 2025
fd380f0
Merge branch 'qwen2-vl-mrope-fix' into add-qwen3-vl
Dec 15, 2025
e7da666
remove blocking ops
shihaobai Dec 15, 2025
f4d10cf
fix start_idx used
Dec 15, 2025
abfb4ec
fix tap
Dec 15, 2025
ee4710c
fix-mrope
Dec 15, 2025
e45189c
Merge remote-tracking branch 'origin/qwen2-vl-mrope-fix' into add-qwe…
Dec 15, 2025
b699c60
add-qwen3-vl
Dec 16, 2025
8f97e99
Merge branch 'main' into add-qwen3-vl
Dec 16, 2025
c63cae9
add-qwen3-vl
Dec 16, 2025
0e7047d
import deepstack
Dec 16, 2025
fa45ff9
add-qwen3-vl-1216
Dec 16, 2025
f5d1d60
refactor mrope
shihaobai Dec 16, 2025
ce02b13
Merge branch 'add-qwen2-vl' of https://github.com/ModelTC/lightllm in…
shihaobai Dec 16, 2025
49c949f
add-qwen3-vl1216
Dec 16, 2025
8d33f1a
fix
shihaobai Dec 16, 2025
884c227
Merge branch 'add-qwen3-vl' of https://github.com/ModelTC/lightllm in…
shihaobai Dec 16, 2025
dc2aad9
fix
shihaobai Dec 16, 2025
4519d57
add-qwen3-vl-1216
Dec 16, 2025
057bb1d
openai samping params
shihaobai Dec 16, 2025
5d1baef
Merge branch 'add-qwen3-vl' of https://github.com/ModelTC/lightllm in…
shihaobai Dec 16, 2025
b59dc5b
remove qwen2-vl resize
Dec 17, 2025
88c33c6
fix deepstack
Dec 17, 2025
d249aaf
fix cuda
Dec 17, 2025
427c5e8
fix-qwen3-vl-1217
Dec 17, 2025
eaab652
update tensor2bytes
shihaobai Dec 17, 2025
dadf600
merge
shihaobai Dec 17, 2025
16282af
fix
shihaobai Dec 17, 2025
841867d
fix
shihaobai Dec 17, 2025
1321f2e
Merge branch 'main' into add-qwen3-vl
shihaobai Dec 18, 2025
08a3484
refactor mrope
shihaobai Dec 18, 2025
a0c8bf0
qwen3 moe
shihaobai Dec 18, 2025
6df4156
refactor weight
shihaobai Dec 18, 2025
e9e5025
fix
shihaobai Dec 18, 2025
bda9b67
add embed cache one
Dec 18, 2025
042a26b
fix
hiworldwzj Dec 18, 2025
60fc7f5
fix
hiworldwzj Dec 18, 2025
2aec5a1
fix
hiworldwzj Dec 18, 2025
b298257
fix whisper
hiworldwzj Dec 18, 2025
789149f
add cpu embed to llm
hiworldwzj Dec 18, 2025
f4233b6
fix pre layer infer
hiworldwzj Dec 18, 2025
e1e2bd5
fix unittest
hiworldwzj Dec 18, 2025
5d305fb
fix non_blocking
Dec 19, 2025
a0a4c2c
add log
Dec 19, 2025
832b445
fix qwen3_vl pre_layer_infer
Dec 19, 2025
8b528a9
remove tensor2bytes byte2tensor
Dec 19, 2025
e12f8d6
remove dep code
Dec 19, 2025
a1a7d27
fix qwen3 vit out
Dec 19, 2025
8e91c9c
fix
shihaobai Dec 19, 2025
0b7ca92
fix
shihaobai Dec 19, 2025
87cd0fa
qwen3-vl cpu embed
shihaobai Dec 19, 2025
b1f1b66
fix
Dec 19, 2025
06f1ef9
fix get_mrope_position test
Dec 19, 2025
93445ae
fix qk_rms_norm
Dec 19, 2025
8317e83
fix qwenvl
Dec 19, 2025
21 changes: 0 additions & 21 deletions lightllm/common/basemodel/infer_struct.py
@@ -122,27 +122,6 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
attr_.copy_(attr_value, non_blocking=True)
return

def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
"""
Utility function that marks, during chunked prefill, which multimodal objects' tokens actually need
to take part in the computation. Because the input is split into chunks, not every multimodal object's tokens are required for the current chunk.
"""
multi_objs = []
for _, p in enumerate(self.multimodal_params):
for obj in p["images"] + p["audios"]:
multi_objs.append(obj)

if multi_objs:
obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
marks = mark_multimodal_obj(
obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
)
marks_array = marks.detach().cpu().numpy()
for mark, obj in zip(marks_array, multi_objs):
obj["_prefill_"] = mark > 0
return

def prefill_dp_balance(self, input_ids: torch.Tensor):
"""
During prefill in dp mode, re-adjust and redistribute the input data so that the amount of data handled by each dp rank does not become too uneven, which would otherwise cause
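Regarding the mark_multimodal_objs_for_prefill helper deleted above: its effect is easiest to restate in plain PyTorch. The sketch below is a reconstruction from the docstring, not the mark_multimodal_obj Triton kernel itself; the function name and shapes are illustrative assumptions. An object is treated as needed for the current chunk when any token id in the chunk falls inside that object's reserved placeholder token-id range.

import torch

def mark_objs_for_chunk(obj_start_token_ids: torch.Tensor,
                        obj_token_lens: torch.Tensor,
                        input_ids: torch.Tensor) -> torch.Tensor:
    # obj_start_token_ids / obj_token_lens: (num_objs,) tensors describing the
    # placeholder token-id range reserved for each image/audio object
    # input_ids: token ids of the current prefill chunk
    starts = obj_start_token_ids.view(-1, 1)                   # (num_objs, 1)
    ends = (obj_start_token_ids + obj_token_lens).view(-1, 1)  # exclusive end
    hits = (input_ids.view(1, -1) >= starts) & (input_ids.view(1, -1) < ends)
    return hits.any(dim=1)                                     # (num_objs,) bool mask

# example: only the second object overlaps the chunk [105, 106, 107]
marks = mark_objs_for_chunk(torch.tensor([10, 100]), torch.tensor([5, 20]),
                            torch.tensor([105, 106, 107]))    # -> tensor([False, True])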
62 changes: 36 additions & 26 deletions lightllm/common/basemodel/triton_kernel/multimodal_emb.py
@@ -7,20 +7,22 @@
def _fwd_kernel(
Prompt_ids,
Text_weight_embs,
Img_embs,
Embed_cache,
Out,
Img_token_lens,
Img_start_token_ids,
Img_start_locs,
Img_start_locs_in_cache,
stride_text_emb_s,
stride_text_emb_d, # text_stride
stride_img_emb_s,
stride_img_emb_d, # img_stride
stride_emb_cache_s,
stride_emb_cache_l,
stride_emb_cache_d, # img_stride
stride_out_s,
stride_out_d,
tp_text_start_token_id,
tp_text_end_token_id,
hidden_size,
tp_world_size,
BLOCK_HIDDEN_DIM: tl.constexpr,
):

@@ -44,7 +46,7 @@ def _fwd_kernel(
tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)

img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
img_start_loc = tl.load(Img_start_locs_in_cache + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
# load store img emb
for _ in range(
@@ -57,11 +59,16 @@
1,
):
load_emb = tl.load(
Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
Embed_cache
+ stride_emb_cache_s.to(tl.int64) * (img_start_loc + token_id - img_start_token_id)
+ stride_emb_cache_l * 0
+ stride_emb_cache_d * off_d,
mask=off_d < hidden_size,
other=0,
)
tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
tl.store(
Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb / tp_world_size, mask=off_d < hidden_size
)
return


@@ -70,35 +77,38 @@ def multimodal_emb(
out: torch.Tensor,
prompt_ids: torch.Tensor,
text_weight_embs: torch.Tensor,
img_embs: torch.Tensor,
embed_cache: torch.Tensor,
img_token_lens: torch.Tensor,
img_start_token_ids: torch.Tensor,
img_start_locs: torch.Tensor,
tp_text_start_token_id,
tp_text_end_token_id,
img_start_locs_in_cache: torch.Tensor,
tp_text_start_token_id: int,
tp_text_end_token_id: int,
tp_world_size: int,
):
total_len = prompt_ids.shape[0]
BLOCK = triton.next_power_of_2(out.shape[1])
# print(len(img_token_lens))
grid = (total_len, len(img_token_lens) + 1)
num_warps = 1
_fwd_kernel[grid](
prompt_ids,
text_weight_embs,
img_embs,
out,
img_token_lens,
img_start_token_ids,
img_start_locs,
text_weight_embs.stride(0),
text_weight_embs.stride(1),
img_embs.stride(0),
img_embs.stride(1),
out.stride(0),
out.stride(1),
tp_text_start_token_id,
tp_text_end_token_id,
Prompt_ids=prompt_ids,
Text_weight_embs=text_weight_embs,
Embed_cache=embed_cache,
Out=out,
Img_token_lens=img_token_lens,
Img_start_token_ids=img_start_token_ids,
Img_start_locs_in_cache=img_start_locs_in_cache,
stride_text_emb_s=text_weight_embs.stride(0),
stride_text_emb_d=text_weight_embs.stride(1),
stride_emb_cache_s=embed_cache.stride(0),
stride_emb_cache_l=embed_cache.stride(1),
stride_emb_cache_d=embed_cache.stride(2),
stride_out_s=out.stride(0),
stride_out_d=out.stride(1),
tp_text_start_token_id=tp_text_start_token_id,
tp_text_end_token_id=tp_text_end_token_id,
hidden_size=out.shape[1],
tp_world_size=float(tp_world_size),
BLOCK_HIDDEN_DIM=BLOCK,
num_warps=num_warps,
num_stages=1,
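For readers checking the new call signature, here is a minimal PyTorch restatement of what the rewritten multimodal_emb computes, based on the kernel body above. It is a sketch under the assumption that embed_cache is laid out as (num_cache_slots, num_layers, hidden) with layer 0 holding the image embeddings: text tokens in this rank's vocab range are embedded from text_weight_embs, and image tokens are gathered from the embed cache at each image's img_start_locs_in_cache offset, scaled by 1 / tp_world_size because every tensor-parallel rank writes the image rows and the output is all-reduced afterwards.

import torch

def multimodal_emb_reference(prompt_ids, text_weight_embs, embed_cache,
                             img_token_lens, img_start_token_ids,
                             img_start_locs_in_cache,
                             tp_text_start_token_id, tp_text_end_token_id,
                             tp_world_size):
    hidden = text_weight_embs.shape[1]
    out = torch.zeros((prompt_ids.shape[0], hidden),
                      dtype=text_weight_embs.dtype, device=prompt_ids.device)

    # text tokens owned by this tensor-parallel shard of the vocabulary
    is_text = (prompt_ids >= tp_text_start_token_id) & (prompt_ids < tp_text_end_token_id)
    out[is_text] = text_weight_embs[prompt_ids[is_text] - tp_text_start_token_id]

    # image tokens: gather from the shared embed cache; dividing by tp_world_size lets
    # the later all-reduce across ranks restore the original embedding values
    for start_id, tok_len, cache_loc in zip(img_start_token_ids.tolist(),
                                            img_token_lens.tolist(),
                                            img_start_locs_in_cache.tolist()):
        in_img = (prompt_ids >= start_id) & (prompt_ids < start_id + tok_len)
        offsets = prompt_ids[in_img] - start_id
        out[in_img] = embed_cache[cache_loc + offsets, 0, :].to(out.dtype) / tp_world_size
    return out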
2 changes: 2 additions & 0 deletions lightllm/models/__init__.py
@@ -29,6 +29,8 @@
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel
from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel
from lightllm.models.gemma3.model import Gemma3TpPartModel
from lightllm.models.tarsier2.model import (
Tarsier2Qwen2TpPartModel,
61 changes: 30 additions & 31 deletions lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -2,7 +2,6 @@
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
from lightllm.server.embed_cache.utils import bytes2tensor, get_shm_name_embed, read_shm


class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer):
@@ -14,16 +13,15 @@ def __init__(self, network_config, mode):
return

def context_forward(self, input_ids, infer_state, layer_weight):
img_weight = []
img_start_token_ids = []
img_token_lens = []
img_start_loc = 0
img_start_locs = []
img_start_locs_in_cache = []
device = layer_weight.wte_weight_.device
dtype = layer_weight.wte_weight_.dtype
hidden_size = layer_weight.wte_weight_.shape[1]
weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device)

# TODO
scale = self.embed_scale
for idx, input_id in enumerate(input_ids):
if input_id == self.boi_token_index:
@@ -35,45 +33,46 @@ def context_forward(self, input_ids, infer_state, layer_weight):
else:
weight_mask[idx] = scale

infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)

for batch_id, p in enumerate(infer_state.multimodal_params):
for img in p["images"]:
# skip the same image
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
if img["token_id"] in img_start_token_ids:
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
img_start_loc += img["token_num"]
img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
if len(img_weight) > 0:
img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
else:
img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
assert img_weight.shape[1] == hidden_size, (

from lightllm.server.router.model_infer.infer_batch import g_infer_context

cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor

assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
f"Dimension mismatch: text weight dimension is {hidden_size}, "
f"but image weight dimension is {img_weight.shape[1]}"
f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
)
# each tp will fill the img embeds, should divide by world_size
img_weight = img_weight / self.tp_world_size_
img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
img_start_token_ids = torch.tensor(img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True).cuda(
non_blocking=True
)
img_token_lens = torch.tensor(img_token_lens, dtype=torch.long, device="cpu", pin_memory=True).cuda(
non_blocking=True
)
img_start_locs_in_cache = torch.tensor(
img_start_locs_in_cache, dtype=torch.long, device="cpu", pin_memory=True
).cuda(non_blocking=True)

multimodal_emb(
out,
input_ids,
layer_weight.wte_weight_,
img_weight,
img_token_lens,
img_start_token_ids,
img_start_locs,
self.vob_start_id_,
self.vob_end_id_,
out=out,
prompt_ids=input_ids,
text_weight_embs=layer_weight.wte_weight_,
embed_cache=cpu_embed_cache_tensor,
img_token_lens=img_token_lens,
img_start_token_ids=img_start_token_ids,
img_start_locs_in_cache=img_start_locs_in_cache,
tp_text_start_token_id=self.vob_start_id_,
tp_text_end_token_id=self.vob_end_id_,
tp_world_size=self.tp_world_size_,
)
input_dtype = out.dtype
if self.tp_world_size_ > 1:
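The index tensors in the rewritten context_forward are allocated with pin_memory=True before the non_blocking=True device copy. A short standalone illustration of why (variable names here are only for the example): a host-to-device copy can only be issued asynchronously when the source tensor is page-locked; from pageable memory, non_blocking=True silently falls back to a blocking copy.

import torch

idx = [0, 5, 9]

# pageable source: non_blocking is effectively ignored, the copy blocks the host
pageable = torch.tensor(idx, dtype=torch.long)
on_gpu_blocking = pageable.cuda(non_blocking=True)

# pinned source: the copy is queued on the current CUDA stream and can overlap
# with other GPU work, which is what the pre-layer code relies on
pinned = torch.tensor(idx, dtype=torch.long, device="cpu", pin_memory=True)
on_gpu_async = pinned.cuda(non_blocking=True)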
6 changes: 3 additions & 3 deletions lightllm/models/qwen2_vl/infer_struct.py
@@ -33,8 +33,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.position_ids = position_ids.unsqueeze(0).expand(3, -1)

self.position_ids = self.position_ids.contiguous()
self.position_cos = model._cos_cached[self.position_ids] # (3, L, D)
self.position_sin = model._sin_cached[self.position_ids] # (3, L, D)
self.position_cos = model._cos_cached[self.position_ids]
self.position_sin = model._sin_cached[self.position_ids]
if get_env_start_args().enable_fa3:
self.max_seq_len = self.max_kv_seq_len
self.q_max_seq_len = self.max_q_seq_len
@@ -66,7 +66,7 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor:
b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4
b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
b_image_len = torch.tensor(b_image_len, device=self.position_ids.device)
b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
position_ids = self.position_ids.unsqueeze(0).expand(3, -1).contiguous()
get_mrope_position_triton(
b_image_start_idx=b_image_start_idx,
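On the (3, L, D) comments dropped from init_some_extra_state above: position_ids carries one row per mrope axis (temporal, height, width), so indexing the cached cos/sin tables with it yields one rotary table per axis. A quick shape check with made-up sizes:

import torch

L, D, max_pos = 16, 64, 4096
cos_cached = torch.randn(max_pos, D)                        # stand-in for model._cos_cached
position_ids = torch.arange(L).unsqueeze(0).expand(3, -1)   # (3, L): t / h / w rows
position_cos = cos_cached[position_ids]                     # -> (3, L, D)
assert position_cos.shape == (3, L, D)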
27 changes: 12 additions & 15 deletions lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py
@@ -5,31 +5,28 @@
from typing import Tuple
from functools import partial

from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer


class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
def __init__(self, layer_num, network_config, mode=[]):
super().__init__(layer_num, network_config, mode)
self.mrope_section = network_config["rope_scaling"]["mrope_section"]
axis_map = []
for i, n in enumerate(self.mrope_section * 2):
axis_map += [i % 3] * n
self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
mrope_section = network_config["rope_scaling"]["mrope_section"]
self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")

def _get_qkv(self, input, infer_state, layer_weight):
q = layer_weight.q_proj.mm(input)
cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
seq_len, _ = q.shape
q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
self.axis_map = self.axis_map.to(q.device)
k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.axis_map)
new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)

return new_q, cache_kv
mrope_triton_fused(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
self.mrope_section,
is_interleaved=False,
)
return q, cache_kv

def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
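Context for the mrope_triton to mrope_triton_fused switch above: the old path precomputed an axis_map assigning each rotary dimension to one of the three mrope axes, while the fused kernel takes the raw mrope_section tensor and rotates q and the k slice of cache_kv in place, so the transposes and copy-back disappear. The snippet below only reconstructs the removed axis_map derivation to show what mrope_section encodes; that the fused kernel derives the same mapping internally is an assumption of equivalence, not taken from the kernel source.

import torch

# Qwen2-VL style split of the rotary dims across temporal / height / width axes
mrope_section = [16, 24, 24]

axis_map = []
for i, n in enumerate(mrope_section * 2):   # list repeated once: covers both rotary halves
    axis_map += [i % 3] * n                 # 0 = temporal, 1 = height, 2 = width
axis_map = torch.tensor(axis_map, dtype=torch.int32)

# rotary dim j takes its cos/sin from mrope axis axis_map[j],
# i.e. position_cos[axis_map[j], :, j] for position_cos of shape (3, L, D)
assert axis_map.numel() == 2 * sum(mrope_section)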
29 changes: 28 additions & 1 deletion lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py
@@ -138,7 +138,34 @@ def test():
b_q_seq_len,
b_start_loc,
)
print(position_ids)

# print(position_ids)
old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)

position_ids = (
torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
.unsqueeze(0)
.expand(3, -1)
.contiguous()
)
b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda")
b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda")
b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")

get_mrope_position_triton(
b_image_start_idx,
b_image_thwd,
b_image_nums,
b_image_start_num,
b_image_len,
position_ids,
b_ready_cache_len,
b_q_seq_len,
b_start_loc,
)

assert torch.equal(old_value, position_ids)

"""
tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8],
[0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8],