@@ -23,7 +23,7 @@ def moe_dispatch(
     lora_indices: torch.Tensor,
     num_experts: int,
     num_loras: int,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Dispatch tokens to experts for MoE computation.
 
@@ -38,6 +38,7 @@ def moe_dispatch(
         sorted_token_ids: Token indices sorted by expert_id
         sorted_expert_ids: Corresponding expert IDs
         sorted_weights: Corresponding router weights
+        sorted_lora_ids: LoRA adapter IDs for each dispatched token
     """
     num_tokens, top_k = topk_ids.shape
     device = topk_ids.device
@@ -46,15 +47,17 @@ def moe_dispatch(
     flat_topk_ids = topk_ids.flatten()
     flat_topk_weights = topk_weights.flatten()
     flat_token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k)
+    flat_lora_ids = lora_indices.repeat_interleave(top_k)
 
     # Sort by expert_id only (each expert uses same LoRA adapter logic)
-    composite_key = flat_topk_ids
-
-    # Sort by expert_id to group tokens by expert
-    sorted_indices = torch.argsort(composite_key)
+    sorted_indices = torch.argsort(flat_topk_ids)
 
     sorted_token_ids = flat_token_ids[sorted_indices]
     sorted_expert_ids = flat_topk_ids[sorted_indices]
     sorted_weights = flat_topk_weights[sorted_indices]
 
-    return sorted_token_ids, sorted_expert_ids, sorted_weights
+    # Sanity check: repeat_interleave yields one LoRA id per (token, expert) pair
+    assert flat_lora_ids.shape == sorted_indices.shape
+    sorted_lora_ids = flat_lora_ids[sorted_indices]
+
+    return sorted_token_ids, sorted_expert_ids, sorted_weights, sorted_lora_ids
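
For context, here is a minimal call-site sketch of the updated `moe_dispatch` with its new fourth return value. The leading parameters are not visible in this hunk, so the `topk_weights, topk_ids` ordering below is an assumption, and all shapes and values are illustrative only.

```python
import torch

# Illustrative shapes: 4 tokens routed to top-2 of 8 experts, 3 LoRA adapters.
num_tokens, top_k, num_experts, num_loras = 4, 2, 8, 3

topk_weights = torch.rand(num_tokens, top_k).softmax(dim=-1)   # router weights
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))  # chosen experts
lora_indices = torch.randint(0, num_loras, (num_tokens,))      # per-token LoRA id

# Assumed parameter order; the first arguments are outside this diff hunk.
sorted_token_ids, sorted_expert_ids, sorted_weights, sorted_lora_ids = moe_dispatch(
    topk_weights, topk_ids, lora_indices, num_experts, num_loras
)

# The four outputs stay aligned: entry i describes the same dispatched
# (token, expert) pair, now carrying its LoRA adapter id as well.
assert sorted_lora_ids.shape == sorted_token_ids.shape
assert torch.equal(sorted_lora_ids, lora_indices[sorted_token_ids])
```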