
Commit d3b27ee

rename vars for clarity
1 parent af3c758 commit d3b27ee

File tree

3 files changed: +7, -17 lines


python/sglang/srt/layers/moe/lora_moe.py

Lines changed: 1 addition & 3 deletions
@@ -104,12 +104,10 @@ def _compute_lora_delta(
         num_loras = self.lora_a_weights.shape[0]
 
         # Dispatch tokens to experts
-        token_ids, expert_ids, _, lora_ids = moe_dispatch(
+        token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch(
             topk_ids=topk_ids,
             topk_weights=topk_weights,
             lora_indices=lora_indices,
-            num_experts=num_experts,
-            num_loras=num_loras,
         )
 
 
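Note that the call site now binds the router weights, which were previously discarded as `_`. This commit does not show how `sorted_topk_weights` is consumed downstream; a minimal self-contained sketch of one plausible use, scaling each dispatched entry's LoRA delta by its router weight before accumulating per token (all names and shapes here are hypothetical, not from this commit):

import torch

# Hypothetical sizes, not from the commit: 8 dispatched entries over 4 tokens.
num_dispatched, num_tokens, hidden_dim = 8, 4, 16
delta = torch.randn(num_dispatched, hidden_dim)     # per-entry LoRA delta
sorted_topk_weights = torch.rand(num_dispatched)    # as returned by moe_dispatch
token_ids = torch.randint(0, num_tokens, (num_dispatched,))

# Weight each dispatched entry by its router weight, then scatter-add the
# results back into a per-token output buffer.
output = torch.zeros(num_tokens, hidden_dim)
output.index_add_(0, token_ids, delta * sorted_topk_weights.unsqueeze(-1))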

python/sglang/srt/lora/moe_dispatch.py

Lines changed: 3 additions & 10 deletions
@@ -21,8 +21,6 @@ def moe_dispatch(
     topk_ids: torch.Tensor,
     topk_weights: torch.Tensor,
     lora_indices: torch.Tensor,
-    num_experts: int,
-    num_loras: int,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Dispatch tokens to experts for MoE computation.
@@ -31,13 +29,11 @@
         topk_ids: [num_tokens, top_k] - Expert IDs selected by router
         topk_weights: [num_tokens, top_k] - Router weights
         lora_indices: [num_tokens] - LoRA adapter ID for each token
-        num_experts: Total number of experts
-        num_loras: Total number of LoRA adapters
 
     Returns:
         sorted_token_ids: Token indices sorted by expert_id
         sorted_expert_ids: Corresponding expert IDs
-        sorted_weights: Corresponding router weights
+        sorted_topk_weights: Corresponding router weights
         sorted_lora_ids: LoRA adapter IDs for each dispatched token
     """
     num_tokens, top_k = topk_ids.shape
@@ -54,10 +50,7 @@
 
     sorted_token_ids = flat_token_ids[sorted_indices]
     sorted_expert_ids = flat_topk_ids[sorted_indices]
-    sorted_weights = flat_topk_weights[sorted_indices]
-
-    if flat_lora_ids.shape != sorted_indices.shape:
-        y = 1 # need to pause
+    sorted_topk_weights = flat_topk_weights[sorted_indices]
     sorted_lora_ids = flat_lora_ids[sorted_indices]
 
-    return sorted_token_ids, sorted_expert_ids, sorted_weights, sorted_lora_ids
+    return sorted_token_ids, sorted_expert_ids, sorted_topk_weights, sorted_lora_ids
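Pieced together from the hunks above, the post-commit function plausibly reads as follows. The signature, docstring fields, gather lines, and return statement come from the diff; the flattening and argsort steps in the middle are assumptions, since the diff never shows them:

import torch

def moe_dispatch(
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    lora_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    num_tokens, top_k = topk_ids.shape

    # Flatten the [num_tokens, top_k] routing tables into one row per
    # (token, expert) pair. (Assumed: these lines are not in the diff.)
    flat_token_ids = torch.arange(num_tokens, device=topk_ids.device).repeat_interleave(top_k)
    flat_topk_ids = topk_ids.reshape(-1)
    flat_topk_weights = topk_weights.reshape(-1)
    flat_lora_ids = lora_indices.repeat_interleave(top_k)

    # Stable sort by expert ID so all entries for one expert are contiguous.
    # (Assumed: the diff only shows the gathers below.)
    sorted_indices = torch.argsort(flat_topk_ids, stable=True)

    sorted_token_ids = flat_token_ids[sorted_indices]
    sorted_expert_ids = flat_topk_ids[sorted_indices]
    sorted_topk_weights = flat_topk_weights[sorted_indices]
    sorted_lora_ids = flat_lora_ids[sorted_indices]

    return sorted_token_ids, sorted_expert_ids, sorted_topk_weights, sorted_lora_ids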

test/srt/lora/test_lora_moe.py

Lines changed: 3 additions & 4 deletions
@@ -592,19 +592,18 @@ def test_moe_lora_basic_functionality(self):
         lora_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32)  # tokens 0,1 use lora 0; tokens 2,3 use lora 1
 
         # Run dispatch
-        token_ids, expert_ids, weights = moe_dispatch(
+        token_ids, expert_ids, sorted_topk_weights, lora_ids = moe_dispatch(
             topk_ids=topk_ids,
             topk_weights=topk_weights,
             lora_indices=lora_indices,
-            num_experts=num_experts,
-            num_loras=2,
         )
 
         # Verify results
         # Should have 4 tokens * 2 experts each = 8 dispatched entries
         self.assertEqual(len(token_ids), 8)
         self.assertEqual(len(expert_ids), 8)
-        self.assertEqual(len(weights), 8)
+        self.assertEqual(len(sorted_topk_weights), 8)
+        self.assertEqual(len(lora_ids), 8)
 
         # Check that tokens are grouped by expert (not by LoRA)
         # All tokens going to expert 0 should come first, then expert 1, etc.
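The hunk is cut off at the grouping check; a minimal sketch of an assertion that would express it, assuming the test continues in the same unittest style (these lines are not part of the commit):

        # Expert IDs must be non-decreasing, i.e. entries grouped by expert.
        self.assertTrue(torch.all(expert_ids[:-1] <= expert_ids[1:]).item())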
