Optimize singleton MoE collectives #7997
Open
Tianyi-Franklin-Wang wants to merge 1 commit into deepspeedai:master from
Conversation
Signed-off-by: Tianyi Wang <npufranklin@gmail.com>
Pull request overview
This PR optimizes DeepSpeed MoE execution for singleton expert-parallel groups by avoiding identity collectives (all-to-all and capacity all-reduce MAX) when the expert-parallel world size is 1, while preserving existing behavior for non-singleton expert-parallel groups.
Changes:
- Skip the capacity all_reduce(op=MAX) in top-1/top-2/top-k gating when ep_group exists but has world size 1 (a sketch of this guard follows the list).
- Skip the two MoE all-to-all calls in MOELayer.forward when self.ep_size == 1, replacing them with .contiguous() to preserve layout normalization.
- Add unit tests to validate that singleton expert-parallel configurations do not invoke these collectives.
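A minimal sketch of the gating-side guard described in the first bullet, written against plain torch.distributed; the helper name is illustrative and the real code in deepspeed/moe/sharded_moe.py may be structured differently:

```python
import torch
import torch.distributed as dist


def reduce_capacity_across_experts(capacity: torch.Tensor, ep_group) -> torch.Tensor:
    """Illustrative guard: only all-reduce the capacity when the expert-parallel
    group actually has peers; on a singleton group the MAX reduce is an identity."""
    if ep_group is not None and dist.get_world_size(group=ep_group) > 1:
        dist.all_reduce(capacity, op=dist.ReduceOp.MAX, group=ep_group)
    return capacity
```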
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| deepspeed/moe/sharded_moe.py | Adds singleton fastpaths for capacity reduction and MoE all-to-all to avoid unnecessary collectives when EP is size 1. |
| tests/unit/moe/test_moe.py | Adds tests asserting the singleton fastpaths avoid collectives and that ep_group=None does not enter EP collective logic. |
Comment on lines +232 to +247
```python
all_to_all_calls = []

def counted_all_to_all(group, input):
    all_to_all_calls.append((group, input.shape))
    return input

monkeypatch.setattr(sharded_moe._AllToAll, "apply", counted_all_to_all)

x = torch.randn(1, 4, hidden_dim, device=model.device, requires_grad=True)
output, l_aux, _ = model(x)
assert len(all_to_all_calls) == expected_calls

loss = output.float().sum() + l_aux.float()
model.backward(loss)
assert len(all_to_all_calls) == expected_calls
assert x.grad is not None
```
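Monkeypatching sharded_moe._AllToAll.apply with a pass-through counter lets the test observe how many all-to-all exchanges the forward and backward pass issue without requiring a multi-rank launch; presumably expected_calls is 0 on the singleton fastpath and 2 (dispatch plus combine) otherwise, and the final assertions check that gradients still flow back to the input.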
Summary
This PR avoids unnecessary MoE collectives when the expert-parallel group has a single rank.
The change is narrow:
- skip the MOELayer all-to-all calls when ep_size == 1
- skip the capacity all_reduce(MAX) when the explicit expert-parallel group has world size 1

Motivation
For a singleton expert-parallel group, these collectives are identity operations:
- all_to_all_single(..., group=ep_group) has no remote rank to exchange with
- all_reduce(..., op=MAX, group=ep_group) leaves the tensor unchanged (see the check below)

In a downstream ep_size=1 MoE run, profiling showed repeated singleton all-to-all and capacity all-reduce calls dominating late-step time. A local version of this guarded optimization reduced late-step timing from around 13s/update to around 0.864s/update while keeping loss and MoE auxiliary loss finite. This is related to #7141, which also reports ep_size=1 MoE all-to-all behavior.
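The identity claim can be checked with a one-rank process group; a minimal standalone sketch, assuming the gloo backend and an illustrative TCP init address (all_to_all support varies by backend, so only the all-reduce is exercised here):

```python
import torch
import torch.distributed as dist

# One-rank group: collectives over it degenerate to local no-ops/copies.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

t = torch.tensor([3.0, 1.0, 2.0])
before = t.clone()
dist.all_reduce(t, op=dist.ReduceOp.MAX)  # MAX over a single rank
assert torch.equal(t, before)             # tensor is unchanged

dist.destroy_process_group()
```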
Correctness

The fastpaths are guarded by the existing MoE runtime structure:
- MOELayer skips _AllToAll.apply(...) only when self.ep_size == 1 (sketched after this list)
- the skipped exchange is replaced with .contiguous(), preserving the layout normalization previously performed inside _AllToAll.forward
- gating skips the capacity all-reduce only when an explicit ep_group has world size 1
- ep_group=None is not treated as a singleton expert group

This does not change routing, capacity math, expert execution, combine logic, auxiliary loss, or expert counts.
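A minimal sketch of the forward-path guard shape, assuming the existing _AllToAll autograd op in deepspeed.moe.sharded_moe; the helper and argument names are illustrative, not the PR's exact code:

```python
import torch


def maybe_all_to_all(tensor: torch.Tensor, ep_group, ep_size: int) -> torch.Tensor:
    """Illustrative guard: skip the exchange on a singleton expert-parallel group
    while keeping the .contiguous() layout normalization the exchange used to do."""
    if ep_size == 1:
        return tensor.contiguous()
    # Non-singleton groups keep the existing collective path.
    from deepspeed.moe.sharded_moe import _AllToAll  # existing autograd collective
    return _AllToAll.apply(ep_group, tensor)
```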
Testing
- pre-commit run --files deepspeed/moe/sharded_moe.py tests/unit/moe/test_moe.py
- git diff --check
- pytest --forked tests/unit/moe/test_moe.py -v -k "singleton"

The targeted pytest command selected 9 singleton tests locally, but they skipped because this local environment has no accelerator, matching the existing DistributedTest behavior.

Downstream smoke evidence:
drop_tokens=False