Binary file added baseline_pp.py
Binary file not shown.
8 changes: 7 additions & 1 deletion deepspeed/comm/ccl.py
@@ -8,7 +8,13 @@

import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import NotImplementedBuilder
try:
from deepspeed.ops.op_builder import NotImplementedBuilder
except ModuleNotFoundError:
# Fallback for environments where the deepspeed/ops symlinks are missing.
# DeepSpeed can still run in this case because the CCL backend gracefully
# disables itself when op builders are unavailable.
from op_builder import NotImplementedBuilder
from .reduce_op import ReduceOp
from .torch import TorchBackend

159 changes: 159 additions & 0 deletions deepspeed/runtime/comm/mori.py
@@ -0,0 +1,159 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

"""mori SDMA backend for the ZeRO-3 all_gather_into_tensor hot path.

Encapsulates every mori-specific import, handle construction and dtype
dispatch so ``deepspeed/runtime/zero/partition_parameters.py`` only needs
to call:

mori.init(max_numel) # one-shot, idempotent
work = mori.allgather_into_tensor(in_, out_) # returns None on fallback

The backend silently fails (no exceptions, ``init`` leaves the handle
unset, ``allgather_into_tensor`` returns ``None``) when mori is missing,
the platform isn't AMD/ROCm, or shmem initialization fails. Callers are
expected to fall back to ``dist.allgather_fn`` in that case.
"""

from typing import Optional

import torch

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import logger

_handle = None
_dtype_map = None
_init_attempted = False
_call_failed_warned = False


class _SdmaWork:
"""Duck-type compatible with ``torch.distributed.Work``.

Mirrors NCCL ``Work.wait()`` semantics: CPU-level blocking AND
GPU-level stream dependency so the current compute stream sees
SDMA-written data.
"""

def __init__(self, event):
self._event = event

def wait(self):
# Stream-level dependency only; do NOT block CPU. RCCL Work.wait()
# is also non-CPU-blocking and the ZeRO-3 prefetch pipeline depends on
# the CPU staying free so the next bucket can be queued ahead of time.
get_accelerator().current_stream().wait_event(self._event)

def is_completed(self) -> bool:
return self._event.query()


def _ensure_default_pg_registered():
"""Register the WORLD process group as 'default' in PyTorch's C++ GroupRegistry.

mori's shmem layer looks up the PG by name "default"; the standard
DeepSpeed init path doesn't register it under that label.
"""
world_group = torch.distributed.group.WORLD
assert world_group is not None, "torch.distributed must be initialized before SDMA allgather"
torch._C._distributed_c10d._register_process_group("default", world_group)


def _build_dtype_map():
"""torch.dtype -> mori_cpp.DataType (NCCL-style enum)."""
from mori.ccl import DataType
return {
torch.uint8: DataType.Uint8,
torch.int8: DataType.Int8,
torch.int16: DataType.Int16,
torch.int32: DataType.Int32,
torch.int64: DataType.Int64,
torch.float16: DataType.Float16,
torch.bfloat16: DataType.BFloat16,
torch.float32: DataType.Float32,
torch.float64: DataType.Float64,
}


def init(max_numel: int = 64 * 1024 * 1024) -> None:
"""Best-effort, idempotent SDMA handle construction.

Builds one ``mori_cpp.AllGatherIntoTensor`` (NCCL/RCCL-style C++
dispatcher) sized for the largest expected per-rank shard. All
subsequent allgather calls reuse this handle.

Safe to call unconditionally: any failure (mori not installed,
non-AMD/ROCm runtime, shmem init error, ...) leaves ``_handle``
unset and logs a single rank-0 warning, so callers transparently
fall back to RCCL/NCCL via ``dist.allgather_fn``.
"""
global _handle, _dtype_map, _init_attempted
if _init_attempted:
return
_init_attempted = True

try:
_ensure_default_pg_registered()
import mori.shmem as shmem
from mori.ccl import AllGatherIntoTensor

shmem.shmem_torch_process_group_init("default")
my_pe = shmem.shmem_mype()
npes = shmem.shmem_npes()
# Per-rank input transit buffer must hold the largest shard we'll
# ever see; output transit buffer = npes * input. 4 B/element is
# the SDMA kernel's uint32 lane width.
input_bytes = max_numel * 4
_handle = AllGatherIntoTensor(
my_pe=my_pe,
npes=npes,
input_buffer_size=input_bytes,
output_buffer_size=input_bytes * npes,
copy_output_to_user=True,
)
_dtype_map = _build_dtype_map()
if dist.is_initialized() and dist.get_rank() == 0:
logger.info("SDMA allgather enabled via mori_cpp.AllGatherIntoTensor")
except Exception as e:
_handle = None
_dtype_map = None
if dist.is_initialized() and dist.get_rank() == 0:
logger.warning(f"SDMA allgather unavailable ({type(e).__name__}: {e}); "
f"falling back to dist.allgather_fn")


def is_enabled() -> bool:
return _handle is not None


def allgather_into_tensor(input_tensor: torch.Tensor,
output_tensor: torch.Tensor) -> Optional[_SdmaWork]:
"""Run one allgather_into_tensor through the SDMA handle.

Returns an ``_SdmaWork`` (Work-compatible) on success. Returns
``None`` if SDMA is disabled or the call fails for any reason — the
caller should then fall back to ``dist.allgather_fn``.
"""
global _call_failed_warned
if _handle is None:
return None
try:
stream = get_accelerator().current_stream()
dtype = _dtype_map[input_tensor.dtype]
ok = _handle(input_tensor.data_ptr(), output_tensor.data_ptr(),
input_tensor.numel(), dtype, stream.cuda_stream)
if not ok:
return None
event = get_accelerator().Event()
event.record(stream)
return _SdmaWork(event)
except Exception as e:
if not _call_failed_warned and dist.is_initialized() and dist.get_rank() == 0:
logger.warning(f"SDMA allgather failed ({e}); falling back to dist.allgather")
_call_failed_warned = True
return None
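
For reference, the caller-side pattern this module expects (and that `_dist_allgather_fn` below adopts) can be restated as a minimal sketch. The `fetch_shard` helper and the tensor-sizing comment are illustrative assumptions, not part of the patch:

import torch
from deepspeed import comm as dist
from deepspeed.runtime.comm import mori

mori.init(max_numel=64 * 1024 * 1024)  # one-shot, idempotent; safe to call unconditionally

def fetch_shard(in_: torch.Tensor, out: torch.Tensor, group=None):
    # out must hold world_size * in_.numel() elements of the same dtype.
    work = mori.allgather_into_tensor(in_, out)
    if work is not None:
        # Work-compatible handle; wait() records a stream dependency
        # without blocking the CPU, so prefetch can keep queueing.
        return work
    # mori missing, init failed, or the call failed: RCCL/NCCL fallback.
    return dist.allgather_fn(out, in_, group=group, async_op=True)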
15 changes: 15 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -50,6 +50,8 @@
"zeropp_loco_param": {...},
"log_trace_cache_warnings" : [true|false],
"enable_sanity_checks": [true|false],
"sdma_allgather": [true|false],
"sdma_allgather_max_numel": 67108864,
}
}
"""
@@ -371,6 +373,19 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Configuration for modules that should be treated as ZeRO3 leaf modules.
"""

sdma_allgather: bool = False
"""
Use mori SDMA allgather instead of RCCL allgather for ZeRO-3 parameter
fetching. Effective only when ``overlap_comm`` is enabled (stage 3).
Requires the ``mori`` package (``mori.ccl.AllGatherIntoTensor``).
"""

sdma_allgather_max_numel: int = Field(pp_int(64 * 1024 * 1024), ge=0)
"""
Maximum number of elements per rank per allgather call when using SDMA
(transit buffers are sized at 4 bytes per element, the SDMA kernel's
uint32 lane width). Controls the pre-allocated transit buffer size
inside ``AllGatherIntoTensor``.
"""

# Validators
@model_validator(mode="after")
def overlap_comm_valid(self):
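Taken together, a minimal stage-3 configuration that opts into the SDMA path would look like the sketch below; everything other than the two new keys is an ordinary ZeRO-3 setting shown for context:

ds_config = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,  # sdma_allgather is only effective with overlap_comm
        "sdma_allgather": True,
        "sdma_allgather_max_numel": 67108864,  # 64 Mi elements (the default)
    },
}
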
54 changes: 30 additions & 24 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -28,6 +28,7 @@
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.config_utils import get_config_default
from deepspeed.runtime.comm import mori
from deepspeed.utils import instrument_w_nvtx, logger
from deepspeed.comm.comm import init_distributed
from deepspeed.utils.debug import (debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name,
@@ -106,6 +107,9 @@ def wait(self, **kwargs) -> None:


def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
work = mori.allgather_into_tensor(input_tensor, output_tensor)
Review comment (P1): Honor ZeRO's process group before using SDMA

When ZeRO is initialized with a non-WORLD data-parallel group, or with a secondary zero-param group, _all_gather passes that group down as ds_process_group (partition_parameters.py:1463-1471), but this new SDMA call ignores the group argument and uses mori's WORLD-backed default process group. In those model/tensor-parallel configurations mori gathers from more ranks than the caller allocated output_tensor for, which can corrupt fetched parameters or write past the expected buffer; fall back unless group is WORLD or make mori initialize/use the matching group.

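A minimal sketch of the guard this comment asks for, taking the simpler of the two suggested resolutions (skip SDMA for any non-WORLD group rather than teaching mori about secondary groups); it assumes `deepspeed.comm` exposes `get_world_group()` here:

def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
    # SDMA was initialized against the "default" (WORLD) process group,
    # so only take the mori path when the caller's group matches it.
    if group is None or group is dist.get_world_group():
        work = mori.allgather_into_tensor(input_tensor, output_tensor)
        if work is not None:
            return work
    return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True)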

if work is not None:
return work
return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True)


@@ -684,7 +688,12 @@ def restore_init_context():

class AllGatherHandle:

def __init__(self, handle, param: Parameter, quantization=None, param_buffer=None, original_dtype=None) -> None:
def __init__(self,
handle,
param: Parameter,
quantization=None,
param_buffer=None,
original_dtype=None) -> None:
if param.ds_status != ZeroParamStatus.INFLIGHT:
raise RuntimeError(f"expected param {param.ds_summary()} to be available")

@@ -698,12 +707,14 @@ def wait(self, handle_dependency=True) -> None:
instrument_w_nvtx(self.__handle.wait)()

if self.__param_buffer is not None:
self.__param.data = self.__param_buffer.narrow(0, 0, self.__param.ds_numel).view(self.__param.ds_shape).to(
gathered = self.__param_buffer.narrow(0, 0, self.__param.ds_numel).view(self.__param.ds_shape).to(
self.__original_dtype).to(self.__param.device)
self.__param.data = gathered
elif self.__quantization:
instrument_w_nvtx(self.__quantization.quant_handle.wait)()
self.__param.data = self.__quantization.backend.dequantize(
self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device)

self.__param.ds_status = ZeroParamStatus.AVAILABLE


@@ -1112,6 +1123,12 @@ def __init__(self,
self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params
self.allgather_sequential = _ds_config.zero_config.allgather_sequential

if _ds_config.zero_config.sdma_allgather:
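# Ensure the SDMA transit buffer also covers twice the per-rank
# prefetch partition, even when the configured cap is smaller.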
cfg_max = _ds_config.zero_config.sdma_allgather_max_numel
prefetch_partition = int(_ds_config.zero_config.prefetch_bucket_size) // self.num_partitions
safe_max = max(cfg_max, prefetch_partition * 2)
mori.init(max_numel=safe_max)

def _update_persist_config(self, ds_config):
Init.apply_param_persistence = True
Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
@@ -1242,7 +1259,8 @@ def _all_gather_dtype(params, world_size, rank_in_group, ds_process_group, allga
if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
total_numel = partition_sz * world_size
flat_tensor = torch.empty(total_numel,
dtype=allgather_dtype,
device=get_accelerator().current_device_name(),
requires_grad=False)
@@ -1261,8 +1279,8 @@
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype) for p in params],
out=partitions[rank_in_group])
handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
#Fix get_partition_dp_group(params[0]))
handle = instrument_w_nvtx(dist.allgather_fn)(
flat_tensor, partitions[rank_in_group], group=ds_process_group, async_op=True)
Review comment on lines +1282 to +1283 (P2): Route coalesced allgathers through the SDMA wrapper

With the default stage3_allgather_sequential=false, any ZeRO-3 fetch containing more than one parameter takes _all_gather_dtype, but this path now calls dist.allgather_fn directly instead of _dist_allgather_fn. As a result, enabling sdma_allgather has no effect for the common coalesced prefetch path (including the added sample config, which does not enable sequential allgather), so the advertised optimization is skipped for most multi-parameter buckets.

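The fix this comment implies is to restore the wrapper call that the hunk above replaced, so the coalesced path keeps the SDMA attempt and its fallback:

# In _all_gather_dtype: route through the SDMA-aware wrapper so that
# sdma_allgather also covers multi-parameter prefetch buckets.
handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)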


return AllGatherCoalescedHandle(
allgather_handle=handle,
@@ -1300,25 +1318,13 @@ def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_
ds_process_group,
)

if original_dtype == allgather_dtype:
param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
handles.append(AllGatherHandle(handle, param))
else:
# This case is complicated:
# We use `register_post_accumulate_grad_hook` to set allgather hooks. Normally, the hook is
# called once per parameter, even if that parameter is tied to multiple layers.
# However, when the dtype changes, the hook may be triggered multiple times.
# If we directly do:
# param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
# as above, the dtype may differ, causing the gradient-reduce hook
# to be invoked multiple times.
# To avoid this, we leave `param.data` in a partitioned state.
# This prevents duplicate gradient-reduce hook calls.
# In theory, this path could be consolidated with the case where
# (original_dtype == allgather_dtype), but because it changes the
# state transition of DeepSpeed parameters, we keep it separate for safety.
handles.append(
AllGatherHandle(handle, param, param_buffer=param_buffer, original_dtype=original_dtype))
handles.append(
AllGatherHandle(
handle,
param,
param_buffer=param_buffer,
original_dtype=original_dtype,
))
else:
if hasattr(param_ds_tensor, "ds_quant_scale"):
scales = param_ds_tensor.ds_quant_scale