Skip to content

Commit 367612a

Browse files
committed
fix lora id issue
1 parent 7fa7ddd commit 367612a

File tree

4 files changed

+619
-123
lines changed

4 files changed

+619
-123
lines changed

python/sglang/srt/layers/moe/lora_moe.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,14 @@ def _compute_lora_delta(
104104
num_loras = self.lora_a_weights.shape[0]
105105

106106
# Dispatch tokens to experts
107-
token_ids, expert_ids, _ = moe_dispatch(
107+
token_ids, expert_ids, _, lora_ids = moe_dispatch(
108108
topk_ids=topk_ids,
109109
topk_weights=topk_weights,
110110
lora_indices=lora_indices,
111111
num_experts=num_experts,
112112
num_loras=num_loras,
113113
)
114114

115-
# Get LoRA IDs for dispatched tokens
116-
lora_ids = lora_indices[token_ids]
117-
118115

119116

120117
# Compute per-expert LoRA forward (adds to base_output in-place)

python/sglang/srt/lora/lora_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
285285

286286
# Populate per-token LoRA indices from segment information
287287
batch_info = self.lora_backend.batch_info
288-
num_tokens = forward_batch.batch_size
288+
num_tokens = forward_batch.seq_lens_sum # Total tokens across all sequences
289289
if batch_info.permutation is None:
290290
# No reordering (e.g., triton backend): segments are in original order
291291
token_lora_indices = torch.empty(num_tokens, dtype=torch.int32, device=batch_info.weight_indices.device)

python/sglang/srt/lora/moe_dispatch.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def moe_dispatch(
2323
lora_indices: torch.Tensor,
2424
num_experts: int,
2525
num_loras: int,
26-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
26+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
2727
"""
2828
Dispatch tokens to experts for MoE computation.
2929
@@ -38,6 +38,7 @@ def moe_dispatch(
3838
sorted_token_ids: Token indices sorted by expert_id
3939
sorted_expert_ids: Corresponding expert IDs
4040
sorted_weights: Corresponding router weights
41+
sorted_lora_ids: LoRA adapter IDs for each dispatched token
4142
"""
4243
num_tokens, top_k = topk_ids.shape
4344
device = topk_ids.device
@@ -46,15 +47,17 @@ def moe_dispatch(
4647
flat_topk_ids = topk_ids.flatten()
4748
flat_topk_weights = topk_weights.flatten()
4849
flat_token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k)
50+
flat_lora_ids = lora_indices.repeat_interleave(top_k)
4951

5052
# Sort by expert_id only (each expert uses same LoRA adapter logic)
51-
composite_key = flat_topk_ids
52-
53-
# Sort by expert_id to group tokens by expert
54-
sorted_indices = torch.argsort(composite_key)
53+
sorted_indices = torch.argsort(flat_topk_ids)
5554

5655
sorted_token_ids = flat_token_ids[sorted_indices]
5756
sorted_expert_ids = flat_topk_ids[sorted_indices]
5857
sorted_weights = flat_topk_weights[sorted_indices]
5958

60-
return sorted_token_ids, sorted_expert_ids, sorted_weights
59+
if flat_lora_ids.shape != sorted_indices.shape:
60+
y = 1 # need to pause
61+
sorted_lora_ids = flat_lora_ids[sorted_indices]
62+
63+
return sorted_token_ids, sorted_expert_ids, sorted_weights, sorted_lora_ids

0 commit comments

Comments (0)