-
Notifications
You must be signed in to change notification settings - Fork 4.8k
zero3: SDMA allgather via mori (sdma_allgather) #7999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
55b24f3
fbedb2f
ccb634e
e0eb510
6512ecf
d5f8489
4b2d44d
f3a0d1b
939cc0c
ca01795
33edc8a
5eb18e8
f7d587d
4053ea1
72020df
6b782d9
fc41552
2c5104c
f979a54
5644ae3
2f5eaa6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| """mori SDMA backend for the ZeRO-3 all_gather_into_tensor hot path. | ||
|
|
||
| Encapsulates every mori-specific import, handle construction and dtype | ||
| dispatch so ``deepspeed/runtime/zero/partition_parameters.py`` only needs | ||
| to call: | ||
|
|
||
| mori.init(max_numel) # one-shot, idempotent | ||
| work = mori.allgather_into_tensor(in_, out_) # returns None on fallback | ||
|
|
||
| The backend silently fails (no exceptions, ``init`` leaves the handle | ||
| unset, ``allgather_into_tensor`` returns ``None``) when mori is missing, | ||
| the platform isn't AMD/ROCm, or shmem initialization fails. Callers are | ||
| expected to fall back to ``dist.allgather_fn`` in that case. | ||
| """ | ||
|
|
||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| from deepspeed import comm as dist | ||
| from deepspeed.accelerator import get_accelerator | ||
| from deepspeed.utils import logger | ||
|
|
||
# Module-level singletons: populated once by init(), read on every
# allgather_into_tensor() call.
_handle = None  # mori AllGatherIntoTensor dispatcher; None => SDMA disabled
_dtype_map = None  # torch.dtype -> mori DataType, built alongside _handle
_init_attempted = False  # makes init() one-shot, even after a failed attempt
_call_failed_warned = False  # rank-0 warn-once latch for runtime call failures
|
|
||
|
|
||
class _SdmaWork:
    """Duck-type compatible with ``torch.distributed.Work``.

    ``wait()`` installs a GPU stream-level dependency only — it does NOT
    block the CPU.  This mirrors RCCL ``Work.wait()`` semantics so the
    current compute stream sees SDMA-written data while the host thread
    stays free to queue more work.
    """

    def __init__(self, event):
        # `event` is an accelerator event recorded on the stream the SDMA
        # kernel was enqueued on; reaching it means the gather completed.
        self._event = event

    def wait(self):
        # Stream-level dependency only; do NOT block CPU. RCCL Work.wait()
        # is also non-CPU-blocking and the ZeRO-3 prefetch pipeline depends on
        # the CPU staying free so the next bucket can be queued ahead of time.
        get_accelerator().current_stream().wait_event(self._event)

    def is_completed(self) -> bool:
        # True once the device has executed past the recorded event.
        return self._event.query()
|
||
|
|
||
| def _ensure_default_pg_registered(): | ||
| """Register the WORLD process group as 'default' in PyTorch's C++ GroupRegistry. | ||
|
|
||
| mori's shmem layer looks up the PG by name "default"; the standard | ||
| DeepSpeed init path doesn't register it under that label. | ||
| """ | ||
| world_group = torch.distributed.group.WORLD | ||
| assert world_group is not None, "torch.distributed must be initialized before SDMA allgather" | ||
| torch._C._distributed_c10d._register_process_group("default", world_group) | ||
|
|
||
|
|
||
def _build_dtype_map():
    """torch.dtype -> mori_cpp.DataType (NCCL-style enum)."""
    from mori.ccl import DataType

    # Table-driven: torch dtype paired with the matching DataType member name.
    pairs = (
        (torch.uint8, "Uint8"),
        (torch.int8, "Int8"),
        (torch.int16, "Int16"),
        (torch.int32, "Int32"),
        (torch.int64, "Int64"),
        (torch.float16, "Float16"),
        (torch.bfloat16, "BFloat16"),
        (torch.float32, "Float32"),
        (torch.float64, "Float64"),
    )
    return {dtype: getattr(DataType, name) for dtype, name in pairs}
|
|
||
|
|
||
def init(max_numel: int = 64 * 1024 * 1024) -> None:
    """Best-effort, idempotent SDMA handle construction.

    Builds one ``mori_cpp.AllGatherIntoTensor`` (NCCL/RCCL-style C++
    dispatcher) sized for the largest expected per-rank shard. All
    subsequent allgather calls reuse this handle.

    Safe to call unconditionally: any failure (mori not installed,
    non-AMD/ROCm runtime, shmem init error, ...) leaves ``_handle``
    unset and logs a single rank-0 warning, so callers transparently
    fall back to RCCL/NCCL via ``dist.allgather_fn``.

    :param max_numel: largest per-rank shard (in elements) the transit
        buffers must accommodate; later calls with bigger shards are the
        caller's responsibility to avoid.
    """
    global _handle, _dtype_map, _init_attempted
    # One-shot: a failed attempt is never retried, so a transient error at
    # startup permanently selects the RCCL/NCCL fallback for this process.
    if _init_attempted:
        return
    _init_attempted = True

    try:
        # mori's shmem layer looks the PG up by the name "default".
        _ensure_default_pg_registered()
        import mori.shmem as shmem
        from mori.ccl import AllGatherIntoTensor

        shmem.shmem_torch_process_group_init("default")
        my_pe = shmem.shmem_mype()
        npes = shmem.shmem_npes()
        # Per-rank input transit buffer must hold the largest shard we'll
        # ever see; output transit buffer = npes * input. 4 B/element is
        # the SDMA kernel's uint32 lane width.
        # NOTE(review): _dtype_map also advertises 8-byte dtypes
        # (int64/float64) — confirm 4 B/element sizing is still sufficient
        # for those, or that the kernel rejects them via the `ok` flag.
        input_bytes = max_numel * 4
        _handle = AllGatherIntoTensor(
            my_pe=my_pe,
            npes=npes,
            input_buffer_size=input_bytes,
            output_buffer_size=input_bytes * npes,
            copy_output_to_user=True,
        )
        _dtype_map = _build_dtype_map()
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.info("SDMA allgather enabled via mori_cpp.AllGatherIntoTensor")
    except Exception as e:
        # Deliberate broad catch: SDMA is strictly opportunistic and must
        # never take down training. Warn once on rank 0 only.
        _handle = None
        _dtype_map = None
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.warning(f"SDMA allgather unavailable ({type(e).__name__}: {e}); "
                           f"falling back to dist.allgather_fn")
|
|
||
|
|
||
def is_enabled() -> bool:
    """Report whether init() succeeded and the SDMA handle is usable."""
    enabled = _handle is not None
    return enabled
|
|
||
|
|
||
def allgather_into_tensor(input_tensor: torch.Tensor,
                          output_tensor: torch.Tensor) -> Optional[_SdmaWork]:
    """Run one allgather_into_tensor through the SDMA handle.

    Returns an ``_SdmaWork`` (Work-compatible) on success. Returns
    ``None`` if SDMA is disabled or the call fails for any reason — the
    caller should then fall back to ``dist.allgather_fn``.
    """
    global _call_failed_warned
    # Fast path out when init() never succeeded (or was never called).
    if _handle is None:
        return None
    try:
        stream = get_accelerator().current_stream()
        # Unsupported dtypes raise KeyError here, which the broad except
        # below converts into a silent fallback for the caller.
        dtype = _dtype_map[input_tensor.dtype]
        # The handle enqueues the SDMA kernel on `stream` and returns a
        # truthy flag on success; falsy means it declined this call.
        ok = _handle(input_tensor.data_ptr(), output_tensor.data_ptr(),
                     input_tensor.numel(), dtype, stream.cuda_stream)
        if not ok:
            return None
        # Record an event after the enqueue so _SdmaWork.wait() can chain
        # later streams onto completion without blocking the CPU.
        event = get_accelerator().Event()
        event.record(stream)
        return _SdmaWork(event)
    except Exception as e:
        # Deliberate broad catch: any runtime failure downgrades to the
        # RCCL/NCCL path; warn only once (rank 0) to avoid log spam.
        if not _call_failed_warned and dist.is_initialized() and dist.get_rank() == 0:
            logger.warning(f"SDMA allgather failed ({e}); falling back to dist.allgather")
        _call_failed_warned = True
        return None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param | ||
| from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum | ||
| from deepspeed.runtime.config_utils import get_config_default | ||
| from deepspeed.runtime.comm import mori | ||
| from deepspeed.utils import instrument_w_nvtx, logger | ||
| from deepspeed.comm.comm import init_distributed | ||
| from deepspeed.utils.debug import (debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, | ||
|
|
@@ -106,6 +107,9 @@ def wait(self, **kwargs) -> None: | |
|
|
||
|
|
||
def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
    """All-gather ``input_tensor`` from every rank of ``group`` into ``output_tensor``.

    Tries the mori SDMA path first and falls back to the async
    ``dist.allgather_fn`` (RCCL/NCCL) when SDMA is unavailable or
    declines the call.

    The mori handle is built on the WORLD process group, so SDMA is only
    attempted for the default/WORLD group: with a smaller ``group`` (e.g.
    a non-WORLD data-parallel or secondary zero-param group) mori would
    gather from more ranks than ``output_tensor`` was allocated for and
    corrupt the fetched parameters.
    """
    if group is None or group is dist.get_world_group():
        work = mori.allgather_into_tensor(input_tensor, output_tensor)
        if work is not None:
            return work
    return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True)
|
|
||
|
|
||
|
|
@@ -684,7 +688,12 @@ def restore_init_context(): | |
|
|
||
| class AllGatherHandle: | ||
|
|
||
| def __init__(self, handle, param: Parameter, quantization=None, param_buffer=None, original_dtype=None) -> None: | ||
| def __init__(self, | ||
| handle, | ||
| param: Parameter, | ||
| quantization=None, | ||
| param_buffer=None, | ||
| original_dtype=None) -> None: | ||
| if param.ds_status != ZeroParamStatus.INFLIGHT: | ||
| raise RuntimeError(f"expected param {param.ds_summary()} to be available") | ||
|
|
||
|
|
@@ -698,12 +707,14 @@ def wait(self, handle_dependency=True) -> None: | |
| instrument_w_nvtx(self.__handle.wait)() | ||
|
|
||
| if self.__param_buffer is not None: | ||
| self.__param.data = self.__param_buffer.narrow(0, 0, self.__param.ds_numel).view(self.__param.ds_shape).to( | ||
| gathered = self.__param_buffer.narrow(0, 0, self.__param.ds_numel).view(self.__param.ds_shape).to( | ||
| self.__original_dtype).to(self.__param.device) | ||
| self.__param.data = gathered | ||
| elif self.__quantization: | ||
| instrument_w_nvtx(self.__quantization.quant_handle.wait)() | ||
| self.__param.data = self.__quantization.backend.dequantize( | ||
| self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device) | ||
|
|
||
| self.__param.ds_status = ZeroParamStatus.AVAILABLE | ||
|
|
||
|
|
||
|
|
@@ -1112,6 +1123,12 @@ def __init__(self, | |
| self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params | ||
| self.allgather_sequential = _ds_config.zero_config.allgather_sequential | ||
|
|
||
| if _ds_config.zero_config.sdma_allgather: | ||
| cfg_max = _ds_config.zero_config.sdma_allgather_max_numel | ||
| prefetch_partition = int(_ds_config.zero_config.prefetch_bucket_size) // self.num_partitions | ||
| safe_max = max(cfg_max, prefetch_partition * 2) | ||
| mori.init(max_numel=safe_max) | ||
|
|
||
| def _update_persist_config(self, ds_config): | ||
| Init.apply_param_persistence = True | ||
| Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold | ||
|
|
@@ -1242,7 +1259,8 @@ def _all_gather_dtype(params, world_size, rank_in_group, ds_process_group, allga | |
| if use_secondary_tensor: | ||
| partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) | ||
|
|
||
| flat_tensor = torch.empty(partition_sz * world_size, | ||
| total_numel = partition_sz * world_size | ||
| flat_tensor = torch.empty(total_numel, | ||
| dtype=allgather_dtype, | ||
| device=get_accelerator().current_device_name(), | ||
| requires_grad=False) | ||
|
|
@@ -1261,8 +1279,8 @@ def _all_gather_dtype(params, world_size, rank_in_group, ds_process_group, allga | |
| instrument_w_nvtx(torch.cat)( | ||
| [p.ds_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype) for p in params], | ||
| out=partitions[rank_in_group]) | ||
| handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group) | ||
| #Fix get_partition_dp_group(params[0])) | ||
| handle = instrument_w_nvtx(dist.allgather_fn)( | ||
| flat_tensor, partitions[rank_in_group], group=ds_process_group, async_op=True) | ||
|
Comment on lines
+1282
to
+1283
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
With the default Useful? React with 👍 / 👎. |
||
|
|
||
| return AllGatherCoalescedHandle( | ||
| allgather_handle=handle, | ||
|
|
@@ -1300,25 +1318,13 @@ def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_ | |
| ds_process_group, | ||
| ) | ||
|
|
||
| if original_dtype == allgather_dtype: | ||
| param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) | ||
| handles.append(AllGatherHandle(handle, param)) | ||
| else: | ||
| # This case is complicated: | ||
| # We use `register_post_accumulate_grad_hook` to set allgather hooks. Normally, the hook is | ||
| # called once per parameter, even if that parameter is tied to multiple layers. | ||
| # However, when the dtype changes, the hook may be triggered multiple times. | ||
| # If we directly do: | ||
| # param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) | ||
| # as above, the dtype may differ, causing the gradient-reduce hook | ||
| # to be invoked multiple times. | ||
| # To avoid this, we leave `param.data` in a partitioned state. | ||
| # This prevents duplicate gradient-reduce hook calls. | ||
| # In theory, this path could be consolidated with the case where | ||
| # (original_dtype == allgather_dtype), but because it changes the | ||
| # state transition of DeepSpeed parameters, we keep it separate for safety. | ||
| handles.append( | ||
| AllGatherHandle(handle, param, param_buffer=param_buffer, original_dtype=original_dtype)) | ||
| handles.append( | ||
| AllGatherHandle( | ||
| handle, | ||
| param, | ||
| param_buffer=param_buffer, | ||
| original_dtype=original_dtype, | ||
| )) | ||
| else: | ||
| if hasattr(param_ds_tensor, "ds_quant_scale"): | ||
| scales = param_ds_tensor.ds_quant_scale | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When ZeRO is initialized with a non-WORLD data-parallel group, or with a secondary zero-param group,
`_all_gather` passes that group down as `ds_process_group` (partition_parameters.py:1463-1471), but this new SDMA call ignores the `group` argument and uses mori's WORLD-backed default process group. In those model/tensor-parallel configurations mori gathers from more ranks than the caller allocated `output_tensor` for, which can corrupt fetched parameters or write past the expected buffer; fall back unless `group` is WORLD, or make mori initialize/use the matching group. Useful? React with 👍 / 👎.