97 changes: 97 additions & 0 deletions deepspeed/runtime/engine.py
@@ -2542,6 +2542,103 @@ def no_sync(self):
finally:
self.inside_no_sync_ctxt = False

@contextmanager
def coalesce_grad_reduction(self):
r"""Coalesce ZeRO 1/2/3 gradient reduction across multiple engine.backward()
calls. One with-block == one optimizer step: every backward inside
leaves grads locally on params, and the flush on exit issues a single
reduction pass that populates averaged_gradients for the next step().

Constraints:
- engine.step() inside the block raises.
- Re-entering the block, or nesting it with engine.no_sync(), raises.
- Do not split one accumulation window (gradient_accumulation_steps) across
  multiple with-blocks; each exit's flush overwrites averaged_gradients.

Unsupported (NotImplementedError): ZeRO stage 0, BF16/FP16_Optimizer
wrappers, PipelineModule.
"""
stage = self.zero_optimization_stage()
if stage not in (ZeroStageEnum.optimizer_states, ZeroStageEnum.gradients, ZeroStageEnum.weights):
raise NotImplementedError(f"coalesce_grad_reduction requires ZeRO stage 1/2/3, got stage {int(stage)}")
if self.pipeline_parallelism:
raise NotImplementedError("coalesce_grad_reduction is not supported under pipeline parallelism")
optimizer = self.optimizer
if not hasattr(optimizer, "_coalesce_grad_reduction"):
# BF16_Optimizer / FP16_Optimizer route grads through their own
# backward_epilogue path, bypassing DeepSpeedZeroOptimizer's
# per-param hooks that this context relies on.
raise NotImplementedError(
f"coalesce_grad_reduction does not yet support optimizer wrapper {type(optimizer).__name__}")
assert not self.inside_no_sync_ctxt, \
"coalesce_grad_reduction cannot be nested inside another no_sync context"

# Engine boundary is the source of truth; optimizer's copy is overwritten
# by _backward_prologue from the engine value on each backward, so we
# only need to save/restore the engine flag.
saved_engine_boundary = self._is_gradient_accumulation_boundary
self.inside_no_sync_ctxt = True
optimizer._coalesce_grad_reduction = True
try:
yield
finally:
# Reset _coalesce_grad_reduction BEFORE the flush so the reducer calls
# we drive in the flush helpers do NOT short-circuit at our guard
# in process_gradients / reduce_ready_partitions_and_remove_grads.
optimizer._coalesce_grad_reduction = False
self.inside_no_sync_ctxt = False
self._is_gradient_accumulation_boundary = True
optimizer.is_gradient_accumulation_boundary = True
try:
# Drive a single reduction pass over locally accumulated grads.
# Iterate explicitly (rather than calling reduce_gradients) so
# the path works regardless of overlap_comm / contiguous_gradients,
# both of which alter reduce_gradients's control flow.
if stage == ZeroStageEnum.weights:
self._flush_coalesced_reduction_zero3(optimizer)
else:
self._flush_coalesced_reduction_zero12(optimizer)
finally:
self._is_gradient_accumulation_boundary = saved_engine_boundary

def _flush_coalesced_reduction_zero12(self, optimizer):
# Quiesce the reduction stream before re-entering it (overlap_comm uses
# a separate stream + double-buffered ipg bucket). Without this the
# bucket.index swap in reduce_independent_p_g_buckets_and_remove_grads
# may race against the previous step's residual reduction.
if getattr(optimizer, "overlap_comm", False) and hasattr(optimizer, "reduction_stream"):
if not get_accelerator().resolves_data_dependency():
optimizer.reduction_stream.synchronize()
# Ensure the ipg bucket buffers exist (process_gradients normally allocates
# them via setup_buckets, but we suppressed that call during the coalesce block).
# Note: micro_step_id advances only once here for the entire coalesce block,
# which is fine -- copy_grads_in_partition's accumulate condition checks
# micro_step_id > 0 OR not boundary, and we force boundary=True.
optimizer.setup_buckets()
for i, group in enumerate(optimizer.bit16_groups):
for param in group:
if not param.requires_grad:
continue
# use_grad_accum_attribute=True parks the accumulated grad in
# param.grad_accum instead of param.grad (backward_epilogue
# routes it there each microbatch). get_gradient_for_reduction
# returns the right one for both modes.
if optimizer.get_gradient_for_reduction(param) is None:
continue
optimizer.reduce_ready_partitions_and_remove_grads(param, i)
optimizer.overlapping_partition_gradients_reduce_epilogue()

def _flush_coalesced_reduction_zero3(self, optimizer):
# Leaf-module unused-param zero-fill (stage3.py:1336-1337) runs from
# the leaf module's own backward hook, BEFORE the reducer call we
# suppress. So by flush time the leaf params already have grads (real
# or zero-filled) populated by the hook regardless of _coalesce_grad_reduction.
for group in optimizer.fp16_groups:
for param in group:
if param.requires_grad and param.grad is not None:
optimizer.reduce_ready_partitions_and_remove_grads(param)
optimizer.independent_gradient_partition_epilogue()

def scale(self, loss):
r"""Apply loss scaler for manual backward pass.

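For reference, a minimal usage sketch of the new context manager. This is not part of the diff; `engine`, `micro_batches`, and `compute_loss` are placeholder names for an existing DeepSpeedEngine handle and the caller's own data/loss code.

# Hedged usage sketch -- placeholder names, not code from this PR.
with engine.coalesce_grad_reduction():
    for batch in micro_batches:              # every micro-batch for ONE optimizer step
        loss = compute_loss(engine, batch)   # placeholder forward + loss
        engine.backward(loss)                # grads stay local; no reduction here
# a single coalesced reduction ran when the with-block exited
engine.step()                                # step() inside the block would raise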
5 changes: 5 additions & 0 deletions deepspeed/runtime/zero/stage3.py
@@ -425,6 +425,9 @@ def _enforce_optimizer_offload():

self.is_gradient_accumulation_boundary: bool = True

# Toggled by DeepSpeedEngine.coalesce_grad_reduction().
self._coalesce_grad_reduction = False

self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
# TODO. make this configurable via JSON
self.max_param_reduce_events: int = 2
@@ -1811,6 +1814,8 @@ def _partitioned_buffers_all_gather(self, params: List[Parameter], buffers_to_al
return output

def reduce_ready_partitions_and_remove_grads(self, param):
if self._coalesce_grad_reduction:
return
#print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
self.reduce_independent_p_g_buckets_and_remove_grads(param)

5 changes: 5 additions & 0 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -242,6 +242,9 @@ def __init__(self,

self.is_gradient_accumulation_boundary = True

# Toggled by DeepSpeedEngine.coalesce_grad_reduction().
self._coalesce_grad_reduction = False

# CPU-Offload requires contiguous gradients
self.contiguous_gradients = contiguous_gradients or self.cpu_offload

@@ -1612,6 +1615,8 @@ def reduce_ipg_grads(self, comm_dtype=None):
#####################################################################

def process_gradients(self, param, i):
if self._coalesce_grad_reduction:
return
self.setup_buckets()
if self.use_grad_accum_attribute:
self._fill_param_grad_accum_attribute(param)
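Stepping back from the diff: the early-return guard added in both ZeRO stage files, together with the flush driven from the engine, amounts to a suppress-and-replay pattern. A self-contained toy sketch of that shape (generic names, not DeepSpeed APIs):

# Toy illustration of suppress-and-replay; Reducer/on_grad_ready are invented names.
from contextlib import contextmanager

class Reducer:
    def __init__(self):
        self._coalesce = False
        self.reduced = []                 # stand-in for the collective reduction

    def on_grad_ready(self, name):        # per-param hook entry point
        if self._coalesce:                # inside the block: skip, grad stays local
            return
        self.reduced.append(name)

    @contextmanager
    def coalesce(self, params):
        self._coalesce = True
        try:
            yield
        finally:
            self._coalesce = False        # clear BEFORE the replay so the hook runs
            for p in params:              # single replay == single reduction pass
                self.on_grad_ready(p)

r = Reducer()
with r.coalesce(["w", "b"]):
    r.on_grad_ready("w")                  # suppressed
    r.on_grad_ready("b")                  # suppressed
assert r.reduced == ["w", "b"]            # reduced exactly once, at block exit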