
zero3: SDMA allgather via mori (sdma_allgather) #7999

Open
inkcherry wants to merge 21 commits into deepspeedai:master from inkcherry:sdma_ag_

Conversation


@inkcherry inkcherry commented May 7, 2026

Summary

RFC: #7884

Wire sdma_allgather into ZeRO-3's parameter prefetch path
(_dist_allgather_fn). When enabled, ZeRO-3 allgather routes through
mori_cpp.AllGatherIntoTensor (intra-node SDMA copy on AMD MI300), with a
transparent fallback to dist.allgather_fn (RCCL/NCCL) on init failure.

End-to-end demo + repro steps + verified numbers live in
examples/sdma_allgather/README.md.

Headline (8x MI300X, DeepSpeed default ZeRO-3 buckets, 100 steps):

|          | GPT-7B-ish      | Qwen3-32B        |
| -------- | --------------- | ---------------- |
| SDMA off | 697.7 ms / step | 1402.5 ms / step |
| SDMA on  | 622.0 ms / step | 1263.2 ms / step |
| gain     | +10.85 %        | +9.93 %          |

Loss curves match with SDMA off and on, and peak memory is unchanged.

Speedup is workload-dependent: gains shrink (or invert) when the allgather cannot be overlapped with compute.

Co-authored-by: wuyl1 <yangwu@amd.com>

wuyl1 and others added 20 commits March 6, 2026 14:46
Move all mori-specific code (handle, dtype map, transit buffer sizing,
PG-registration helper, Work wrapper) out of partition_parameters.py
into a dedicated runtime/comm backend module:

  deepspeed/runtime/comm/mori.py
    mori.init(max_numel)               # one-shot, idempotent, exception-safe
    mori.is_enabled()                  # cheap predicate
    mori.allgather_into_tensor(in, out)
        -> Work-compatible object on success, None on fallback

The new backend uses mori_cpp.AllGatherIntoTensor (NCCL/RCCL-style
flat->flat C++ dispatcher) instead of the old mori.ccl.AllgatherSdma
templated Python wrapper, so DeepSpeed no longer has to pre-convert
numel into uint32 lane counts or template the C++ class on dtype.

partition_parameters.py is now agnostic to the SDMA path:

    def _dist_allgather_fn(input_tensor, output_tensor, group=None):
        work = mori.allgather_into_tensor(input_tensor, output_tensor)
        if work is not None:
            return work
        return instrument_w_nvtx(dist.allgather_fn)(...)

Init failure (mori missing, non-AMD/ROCm runtime, shmem init error)
leaves the handle unset and logs a single rank-0 warning, so the SDMA
path silently no-ops and dist.allgather_fn (RCCL/NCCL) takes over —
no hard fail.
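
For reference, a minimal sketch of how such a backend module can be structured. Only mori_cpp.AllGatherIntoTensor is taken from this PR; its constructor/call signature and the helper names below are assumptions, and the actual deepspeed/runtime/comm/mori.py may differ in details:

    # Hypothetical sketch of the SDMA backend module (not the actual file).
    import logging

    import torch.distributed as dist

    logger = logging.getLogger(__name__)

    _handle = None  # set exactly once by init(); stays None on any failure


    def init(max_numel: int) -> None:
        """One-shot, idempotent, exception-safe initialization."""
        global _handle
        if _handle is not None:
            return
        try:
            import mori_cpp  # raises if mori is missing or the runtime is not ROCm
            _handle = mori_cpp.AllGatherIntoTensor(max_numel)  # assumed constructor
        except Exception as e:
            _handle = None
            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.warning(f"SDMA allgather disabled, falling back to RCCL/NCCL: {e}")


    def is_enabled() -> bool:
        return _handle is not None


    def allgather_into_tensor(input_tensor, output_tensor):
        """Work-compatible object on success, None to signal fallback."""
        if _handle is None:
            return None
        return _handle(input_tensor, output_tensor)  # assumed: dispatcher is callable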

Net change: partition_parameters.py shrinks by 79 lines; one new
self-contained module under runtime/comm/.
_SdmaWork.wait() previously blocked the CPU on _event.synchronize()
before issuing the stream-level dependency.  RCCL's Work.wait() only
records a stream-level wait (cudaStreamWaitEvent / hipStreamWaitEvent)
and does NOT block the CPU, which is what the ZeRO-3 prefetch pipeline
relies on: while bucket N is in flight on the GPU, the CPU is free to
queue bucket N+1 so it can overlap with the trailing compute of N.

The CPU-blocking variant turned out to be a per-step critical-path tax
that wiped out SDMA's headroom on workloads that issue many small
allgathers per step.  Concretely, on Qwen3-32B + ZeRO-3 + seq_len=128,
8x MI300X, ~6400 prefetch buckets per step:

  before:  SDMA  1014 ms / step  (1009 tok/s)
           RCCL   932 ms / step  (1099 tok/s)   -> SDMA -8.0%
  after:   SDMA   927 ms / step  (1104 tok/s)
           RCCL   929 ms / step  (1100 tok/s)   -> within noise

Loss curve is bit-identical with and without the CPU sync, so this is
purely a CPU-pipelining fix.  is_completed() is unchanged (it polls
via _event.query() without blocking, same as before).
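
A sketch of the before/after wait() behaviour, assuming the wrapper tracks completion with a CUDA/HIP event recorded on the collective's stream (class and field names are illustrative):

    import torch


    class _SdmaWork:
        """Illustrative Work-compatible wrapper around a recorded completion event."""

        def __init__(self, event: torch.cuda.Event):
            self._event = event

        def wait(self):
            # Before: self._event.synchronize() blocked the host thread here,
            # serializing prefetch bucket N+1 behind bucket N.
            # After: only record a stream-level dependency (hipStreamWaitEvent
            # under ROCm), matching RCCL's Work.wait(), so the CPU keeps
            # queueing the next bucket while this one is still in flight.
            torch.cuda.current_stream().wait_event(self._event)

        def is_completed(self) -> bool:
            # Unchanged: non-blocking poll of the event.
            return self._event.query()
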
Move the zero3_overlap demo dir into examples/sdma_allgather/ (the name
that matches the feature being demoed) and add a Qwen3-32B + ZeRO-3
trainer that reproduces the +9.93% end-to-end speedup of this PR on
8x MI300X with the default DeepSpeed bucket sizes.

Layout:
  ds_config_zero3_{sdma,nosdma}.json   ZeRO-3 + bf16 + DS-default buckets
  run_gpt_sdma_{on,off}.sh             GPT-7B-ish demo (existing trainer)
  run_qwen3_sdma_{on,off}.sh           Qwen3-32B demo (new trainer)
  train_qwen3_zero3.py                 self-contained Qwen3 trainer
  README.md                            feature overview + repro steps
  train_zero3.py                       unchanged (renamed only)
  test_sdma_allgather_zero3.py         unchanged (renamed only)

train_qwen3_zero3.py inlines a minimal wikitext-103 dataloader so the
benchmark has no dependency on external benchmark repos.  Loading via
AutoConfig + from_config keeps the example weight-free; only the model
config and tokenizer are pulled from HuggingFace.
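
A sketch of that weight-free construction; the exact Hugging Face model id and any dtype handling in train_qwen3_zero3.py are assumptions:

    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen3-32B"  # assumed HF id; only config + tokenizer are downloaded
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_config(config)  # randomly initialized, no weights fetched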

The configs use DeepSpeed's default ZeRO-3 bucket sizes
(stage3_prefetch_bucket_size = 5e7, etc.) so the published numbers
in README.md are reproducible without any tuning.
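
An illustrative Python-dict equivalent of the SDMA config: the bucket size is the DeepSpeed ZeRO-3 default named above, while the name and placement of the SDMA toggle inside zero_optimization are assumptions (see ds_config_zero3_sdma.json for the authoritative version):

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,
            "stage3_prefetch_bucket_size": 5e7,  # DeepSpeed default, no tuning
            "sdma_allgather": True,              # hypothetical key placement
        },
    }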

Verified on 8x MI300X, two fresh rounds:
  Qwen3-32B + ZeRO-3 + DP=8, seq_len=1024, micro_bs=1, 100 steps
    SDMA off : 1402.5 ms / step   (5841 tok/s)
    SDMA on  : 1263.2 ms / step   (6486 tok/s)   -> +9.93% e2e

  GPT-7B + ZeRO-3 + DP=8, 100 steps                -> +5.9% e2e

Loss curves match across the two backends, peak memory is identical
(96.45 GB), per-step jitter is 1.4-2.7%, so the ~140 ms gap is well
above noise.

Drops:
  examples/zero3_overlap/run.sh                   superseded by run_gpt_*
  examples/zero3_overlap/ds_config_zero3.json     superseded by *_sdma.json

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 66673546b5




def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
    work = mori.allgather_into_tensor(input_tensor, output_tensor)


P1: Honor ZeRO's process group before using SDMA

When ZeRO is initialized with a non-WORLD data-parallel group, or with a secondary zero-param group, _all_gather passes that group down as ds_process_group (partition_parameters.py:1463-1471), but this new SDMA call ignores the group argument and uses mori's WORLD-backed default process group. In those model/tensor-parallel configurations mori gathers from more ranks than the caller allocated output_tensor for, which can corrupt fetched parameters or write past the expected buffer; fall back unless group is WORLD or make mori initialize/use the matching group.
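
One possible shape for the suggested guard, as a sketch against the wrapper quoted above. The helper name is illustrative; mori, dist, instrument_w_nvtx and Tensor are the names already imported in partition_parameters.py, and the fallback call mirrors the existing wrapper:

    def _sdma_group_ok(group) -> bool:
        # Only take the SDMA path when the requested group spans all ranks,
        # so mori's WORLD-backed default group matches output_tensor's size.
        return group is None or dist.get_world_size(group) == dist.get_world_size()


    def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
        if _sdma_group_ok(group):
            work = mori.allgather_into_tensor(input_tensor, output_tensor)
            if work is not None:
                return work
        return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor,
                                                    group=group, async_op=True)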


Comment on lines +1282 to +1283
handle = instrument_w_nvtx(dist.allgather_fn)(
    flat_tensor, partitions[rank_in_group], group=ds_process_group, async_op=True)


P2: Route coalesced allgathers through the SDMA wrapper

With the default stage3_allgather_sequential=false, any ZeRO-3 fetch containing more than one parameter takes _all_gather_dtype, but this path now calls dist.allgather_fn directly instead of _dist_allgather_fn. As a result, enabling sdma_allgather has no effect for the common coalesced prefetch path (including the added sample config, which does not enable sequential allgather), so the advertised optimization is skipped for most multi-parameter buckets.
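
A sketch of the suggested change at the lines quoted above, routing the coalesced gather through the same wrapper so the SDMA path and its fallback both apply (argument order follows the wrapper's input-then-output signature):

    handle = _dist_allgather_fn(partitions[rank_in_group],  # this rank's shard (input)
                                flat_tensor,                # coalesced output buffer
                                group=ds_process_group)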

