
Optimize singleton MoE collectives #7997

Open
Tianyi-Franklin-Wang wants to merge 1 commit into deepspeedai:master from Tianyi-Franklin-Wang:fix/moe-singleton-collectives

Conversation

@Tianyi-Franklin-Wang

Summary

This PR avoids unnecessary MoE collectives when the expert-parallel group has a single rank.

The change is narrow:

  • skip the two MOELayer all-to-all calls when ep_size == 1 (sketched after this list)
  • skip top-1/top-2/top-k capacity all_reduce(MAX) when the explicit expert-parallel group has world size 1
  • keep the existing collective paths unchanged for non-singleton expert-parallel groups
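
A minimal sketch of the all-to-all guard from the first bullet, written here as a hypothetical helper (exchange is not a name in the patch) around the two MOELayer.forward call sites:

from deepspeed.moe.sharded_moe import _AllToAll

def exchange(tensor, ep_group, ep_size):
    # Hypothetical helper mirroring the guard used at both all-to-all call
    # sites in MOELayer.forward (dispatch and combine).
    if ep_size == 1:
        # Singleton expert-parallel group: the all-to-all would be an identity
        # exchange, so keep only the .contiguous() layout normalization that
        # _AllToAll.forward would otherwise perform.
        return tensor.contiguous()
    # Non-singleton expert-parallel group: the original collective path is unchanged.
    return _AllToAll.apply(ep_group, tensor)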

Motivation

For a singleton expert-parallel group, these collectives are identity operations:

  • all_to_all_single(..., group=ep_group) has no remote rank to exchange with
  • all_reduce(..., op=MAX, group=ep_group) leaves the tensor unchanged (a short standalone check follows this list)
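
As a standalone sanity check of the all_reduce identity claim (not part of the patch; uses the CPU gloo backend with a single-rank group):

import torch
import torch.distributed as dist

# A one-rank process group: all_reduce(MAX) has nothing to reduce against.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
t = torch.randn(4)
before = t.clone()
dist.all_reduce(t, op=dist.ReduceOp.MAX)
assert torch.equal(t, before)  # tensor is left unchanged
dist.destroy_process_group()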

In a downstream ep_size=1 MoE run, profiling showed repeated singleton all-to-all and capacity all-reduce calls dominating late-step time. A local version of this guarded optimization reduced late-step timing from around 13s/update to around 0.864s/update while keeping loss and MoE auxiliary loss finite.

This is related to #7141, which also reports ep_size=1 MoE all-to-all behavior.

Correctness

The fast paths are guarded by the existing MoE runtime structure:

  • MOELayer skips _AllToAll.apply(...) only when self.ep_size == 1
  • the singleton all-to-all path still calls .contiguous(), preserving the layout normalization previously performed inside _AllToAll.forward
  • gate capacity reduction checks the runtime world size of the explicit ep_group
  • ep_group=None is not treated as a singleton expert group
  • non-singleton expert-parallel groups still use the original collectives

This does not change routing, capacity math, expert execution, combine logic, auxiliary loss, or expert counts.
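
For illustration, the capacity-reduction guard could be written as the following hypothetical helper; reduce_capacity, group, and new_capacity are assumed names, and the real gating code may structure this differently:

import torch
import torch.distributed as dist

def reduce_capacity(new_capacity: torch.Tensor, group) -> torch.Tensor:
    # Only all-reduce the capacity when an explicit expert-parallel group with
    # more than one rank is given; a None group is deliberately not treated as
    # a singleton expert group, so the default path stays unchanged.
    if group is not None and dist.get_world_size(group=group) > 1:
        # Non-singleton expert-parallel group: keep the MAX reduction so every
        # rank agrees on the capacity.
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=group)
    return new_capacity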

Testing

  • pre-commit run --files deepspeed/moe/sharded_moe.py tests/unit/moe/test_moe.py
  • git diff --check
  • pytest --forked tests/unit/moe/test_moe.py -v -k "singleton"

The targeted pytest command selected 9 singleton tests locally, but they were skipped because this local environment has no accelerator, matching the existing DistributedTest behavior.

Downstream smoke evidence:

  • 2-rank H200 run
  • top-2 MoE, drop_tokens=False
  • reached update 476 after the local fix
  • finite loss and MoE auxiliary loss
  • late-step timing improved from around 13s/update to around 0.864s/update

Signed-off-by: Tianyi Wang <npufranklin@gmail.com>
Copilot AI review requested due to automatic review settings May 7, 2026 02:28

Copilot AI left a comment


Pull request overview

This PR optimizes DeepSpeed MoE execution for singleton expert-parallel groups by avoiding identity collectives (all-to-all and capacity all-reduce MAX) when the expert-parallel world size is 1, while preserving existing behavior for non-singleton expert-parallel groups.

Changes:

  • Skip capacity all_reduce(op=MAX) in top-1/top-2/top-k gating when ep_group exists but has world size 1.
  • Skip the two MoE all-to-all calls in MOELayer.forward when self.ep_size == 1, replacing them with .contiguous() to preserve layout normalization.
  • Add unit tests to validate that singleton expert-parallel configurations do not invoke these collectives.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Files reviewed:
  • deepspeed/moe/sharded_moe.py: adds singleton fast paths for capacity reduction and MoE all-to-all to avoid unnecessary collectives when EP is size 1.
  • tests/unit/moe/test_moe.py: adds tests asserting the singleton fast paths avoid collectives and that ep_group=None does not enter EP collective logic.


Comment on lines +232 to +247
all_to_all_calls = []

def counted_all_to_all(group, input):
    all_to_all_calls.append((group, input.shape))
    return input

monkeypatch.setattr(sharded_moe._AllToAll, "apply", counted_all_to_all)

x = torch.randn(1, 4, hidden_dim, device=model.device, requires_grad=True)
output, l_aux, _ = model(x)
assert len(all_to_all_calls) == expected_calls

loss = output.float().sum() + l_aux.float()
model.backward(loss)
assert len(all_to_all_calls) == expected_calls
assert x.grad is not None
