From 001f77c363710e3f62e05c5aacbed4b2ff7c8c97 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Fri, 27 Feb 2026 06:30:00 +0000
Subject: [PATCH 1/6] Initial plan

From b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Fri, 27 Feb 2026 06:36:18 +0000
Subject: [PATCH 2/6] Revert "fix: update 1 file reformatted."

This reverts commit ff886701c392ab03863c227de14fbe1d671d4173.

Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com>
---
 deepspeed/runtime/zero/stage_1_and_2.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 107e47a44042..183fd077f8a9 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -283,11 +283,18 @@ def _enforce_cpu_offload():
 
         self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32
 
+        # Check for Muon optimizer usage
+        self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params'])
+
         if self.reduce_scatter and self.partition_gradients:
             valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
             assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
             assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
             assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
+
+        # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2)
+        if self.reduce_scatter and self.uses_muon:
+            assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer."
 
         # param flattened by groups
         self.bit16_groups = []
@@ -1187,7 +1194,9 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype
 
         stream = get_accelerator().current_stream()
         with get_accelerator().stream(stream):
-            if not self.reduce_scatter:
+            # Check if current configuration requires full all-reduce
+            if not self.reduce_scatter or self.uses_muon:
+                # Force full all-reduce for Muon parameters or when reduce_scatter is disabled
                 self.gradient_reduction_w_predivide(tensor, communication_data_type)
                 return
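
Note on the check above: Muon-style optimizers are expected to tag the parameters they manage with a `use_muon` attribute, and the guard simply scans every param group for it. A minimal sketch of that detection outside the engine (the `Linear` module and the SGD stand-in below are illustrative, not part of this patch):

    import torch

    model = torch.nn.Linear(8, 8)
    for p in model.parameters():
        # Muon typically handles only matrix-shaped params; biases stay elsewhere.
        p.use_muon = p.ndim >= 2

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in optimizer
    uses_muon = any(
        getattr(param, 'use_muon', False) for group in optimizer.param_groups
        for param in group['params'])
    print(uses_muon)  # True -> this ZeRO config would reject reduce_scatter=True
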
From cbc816c90f4bd6e10ab5b67f4d471002ade8cba7 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Fri, 6 Mar 2026 06:40:53 +0000
Subject: [PATCH 3/6] Initial plan

From 5fcc9a7e4bf58b1d935dcfeab53143d3cf9dbdf7 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Fri, 6 Mar 2026 06:43:32 +0000
Subject: [PATCH 4/6] Reapply "fix: update 1 file reformatted."

This reverts commit b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35.
---
 deepspeed/runtime/zero/stage_1_and_2.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 1efea00bcbbd..12f97348a21f 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -284,18 +284,11 @@ def _enforce_cpu_offload():
 
         self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32
 
-        # Check for Muon optimizer usage
-        self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params'])
-
         if self.reduce_scatter and self.partition_gradients:
             valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
             assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
             assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
             assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
-
-        # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2)
-        if self.reduce_scatter and self.uses_muon:
-            assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer."
 
         # param flattened by groups
         self.bit16_groups = []
@@ -1224,9 +1217,7 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype
 
         stream = get_accelerator().current_stream()
         with get_accelerator().stream(stream):
-            # Check if current configuration requires full all-reduce
-            if not self.reduce_scatter or self.uses_muon:
-                # Force full all-reduce for Muon parameters or when reduce_scatter is disabled
+            if not self.reduce_scatter:
                 self.gradient_reduction_w_predivide(tensor, communication_data_type)
                 return
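
For reference, the `average_tensor` early-exit restored above picks between two collectives: with reduce-scatter off, the `gradient_reduction_w_predivide` path all-reduces the full flattened gradient so every rank holds the averaged result, while the reduce-scatter path leaves each rank holding only its own partition. A simplified sketch of that distinction, assuming an initialized `torch.distributed` process group and a gradient length divisible by the world size; this is not DeepSpeed's actual implementation:

    import torch
    import torch.distributed as dist

    def reduce_gradients(flat_grad: torch.Tensor, reduce_scatter: bool) -> torch.Tensor:
        world_size = dist.get_world_size()
        if not reduce_scatter:
            # Full all-reduce: every rank ends up with the complete averaged gradient.
            flat_grad.div_(world_size)  # pre-divide to limit low-precision overflow
            dist.all_reduce(flat_grad)
            return flat_grad
        # Reduce-scatter: each rank keeps only its partition, which is all that
        # ZeRO-1/2 needs, at roughly half the communication volume.
        partition = torch.empty(flat_grad.numel() // world_size,
                                dtype=flat_grad.dtype, device=flat_grad.device)
        dist.reduce_scatter_tensor(partition, flat_grad, op=dist.ReduceOp.AVG)
        return partition
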
+ """ + initialized: bool = False + dtype: Optional[torch.dtype] = field(default=None) + + +# Module-level singleton that stores autocast state set by init_autocast_params. +_autocast_state = _AutocastState() def _validate_auto_cast_settings(engine): @@ -56,14 +69,12 @@ def init_autocast_params(engine, dtype: torch.dtype, for p in module.parameters(recurse=False): setattr(p, PARAM_COMM_DTYPE_ATTR_NAME, dtype) - global TORCH_AUTOCAST_INITIALIZED - TORCH_AUTOCAST_INITIALIZED = True - global TORCH_AUTOCAST_DTYPE - TORCH_AUTOCAST_DTYPE = dtype + _autocast_state.initialized = True + _autocast_state.dtype = dtype def is_autocast_initialized() -> bool: - return TORCH_AUTOCAST_INITIALIZED + return _autocast_state.initialized def get_default_autocast_lower_precision_modules() -> List[str]: @@ -71,7 +82,7 @@ def get_default_autocast_lower_precision_modules() -> List[str]: def get_autocast_dtype() -> torch.dtype: - return TORCH_AUTOCAST_DTYPE + return _autocast_state.dtype def has_comm_dtype(param: torch.nn.Parameter) -> bool: From 72cadeb08b8d6b56054c48741ed87c4f0468b329 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Wed, 8 Apr 2026 11:37:04 +0000 Subject: [PATCH 6/6] fix: store autocast state per-engine to support multiple engine configs Signed-off-by: nathon-lee --- deepspeed/runtime/base_optimizer.py | 7 ++++-- deepspeed/runtime/torch_autocast.py | 30 +++++++++++++------------ deepspeed/runtime/zero/stage3.py | 9 ++++---- deepspeed/runtime/zero/stage_1_and_2.py | 9 ++++---- 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index 213b5c659499..22f8cedc513a 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -10,7 +10,7 @@ from deepspeed.utils import logger from deepspeed.utils.tensor_fragment import map_to_flat_opt_states from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage -from deepspeed.runtime.torch_autocast import get_comm_dtype, is_autocast_initialized +from deepspeed.runtime.torch_autocast import get_comm_dtype, has_comm_dtype from deepspeed.runtime.utils import maybe_loss_for_backward @@ -354,7 +354,10 @@ def report_ipg_memory_usage(self, tag, param_elems, dtype=None): ) def get_param_comm_dtype(self, param): - if is_autocast_initialized(): + # Use the per-parameter comm_dtype attribute set by init_autocast_params(). + # Each engine stamps its own parameters, so multiple engines with different + # autocast configs are naturally isolated without a shared global state. + if has_comm_dtype(param): return get_comm_dtype(param) else: return self.communication_data_type diff --git a/deepspeed/runtime/torch_autocast.py b/deepspeed/runtime/torch_autocast.py index 73d627254cc6..7a098a428d5c 100644 --- a/deepspeed/runtime/torch_autocast.py +++ b/deepspeed/runtime/torch_autocast.py @@ -26,20 +26,16 @@ @dataclass class _AutocastState: - """Holds torch-autocast initialization state for the DeepSpeed engine. + """Holds torch-autocast initialization state for one DeepSpeed engine instance. - Using a single object instead of bare module-level variables avoids the - need for ``global`` statements and makes the state easier to reason about - and reset in tests. + Storing this object on the engine (``engine._autocast_state``) rather than as + a module-level singleton allows multiple engine instances to carry independent + autocast configurations without interfering with each other. 
""" initialized: bool = False dtype: Optional[torch.dtype] = field(default=None) -# Module-level singleton that stores autocast state set by init_autocast_params. -_autocast_state = _AutocastState() - - def _validate_auto_cast_settings(engine): assert not engine.zero_quantized_weights(), "Cannot enable both torch autocast and zero quantized weights" @@ -69,20 +65,26 @@ def init_autocast_params(engine, dtype: torch.dtype, for p in module.parameters(recurse=False): setattr(p, PARAM_COMM_DTYPE_ATTR_NAME, dtype) - _autocast_state.initialized = True - _autocast_state.dtype = dtype + engine._autocast_state = _AutocastState(initialized=True, dtype=dtype) -def is_autocast_initialized() -> bool: - return _autocast_state.initialized +def is_autocast_initialized(engine) -> bool: + """Return True if torch autocast was initialised for *this* engine instance. + + Accepts the engine as an argument so that multiple DeepSpeed engines can + carry independent ``_autocast_state`` objects without sharing a + module-level singleton. + """ + return getattr(engine, '_autocast_state', _AutocastState()).initialized def get_default_autocast_lower_precision_modules() -> List[str]: return [f"{cls.__module__}.{cls.__name__}" for cls in LOWER_PRECISION_SAFE_MODULES] -def get_autocast_dtype() -> torch.dtype: - return _autocast_state.dtype +def get_autocast_dtype(engine) -> torch.dtype: + """Return the autocast dtype configured for *this* engine instance.""" + return getattr(engine, '_autocast_state', _AutocastState()).dtype def has_comm_dtype(param: torch.nn.Parameter) -> bool: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 8f28ee4f8685..96424ffd761a 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -19,7 +19,7 @@ from deepspeed.utils import logger from deepspeed.utils.torch import register_grad_hook, required_torch_version from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler -from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes +from deepspeed.runtime.torch_autocast import get_all_comm_dtypes, has_comm_dtype, sort_dtypes from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward from deepspeed.runtime.zero.partition_parameters import * @@ -435,9 +435,10 @@ def _enforce_optimizer_offload(): self.is_param_in_current_partition = {} self.torch_autocast_gradscaler = None - if is_autocast_initialized(): - comm_dtypes = get_all_comm_dtypes([p for params in self.fp16_groups for p in params]) - if get_autocast_dtype() == torch.float16: + all_params = [p for params in self.fp16_groups for p in params] + if any(has_comm_dtype(p) for p in all_params): + comm_dtypes = get_all_comm_dtypes(all_params) + if torch.float16 in comm_dtypes: self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name()) else: comm_dtypes = {self.communication_data_type} diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index f3a0352bebfa..2a7bd684706e 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -19,7 +19,7 @@ from deepspeed.runtime.zero.offload_states import offload_optimizer_states, reload_optimizer_states from deepspeed.runtime.base_optimizer import 
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index f3a0352bebfa..2a7bd684706e 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -19,7 +19,7 @@
 from deepspeed.runtime.zero.offload_states import offload_optimizer_states, reload_optimizer_states
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
-from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
+from deepspeed.runtime.torch_autocast import get_all_comm_dtypes, has_comm_dtype, sort_dtypes
 from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
                                      align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
                                      count_used_parameters_in_backward)
@@ -518,9 +518,10 @@ def _enforce_cpu_offload():
         self.is_param_in_current_partition = {}
 
         self.torch_autocast_gradscaler = None
-        if is_autocast_initialized():
-            comm_dtypes = get_all_comm_dtypes([p for params in self.bit16_groups for p in params])
-            if get_autocast_dtype() == torch.float16:
+        all_params = [p for params in self.bit16_groups for p in params]
+        if any(has_comm_dtype(p) for p in all_params):
+            comm_dtypes = get_all_comm_dtypes(all_params)
+            if torch.float16 in comm_dtypes:
                 self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name())
         else:
             comm_dtypes = {self.communication_data_type}
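
Taken together, the per-engine storage means two engines can hold different autocast configurations at the same time. A sketch against the patched module; `FakeEngine` is an illustrative stand-in for the engine objects that `deepspeed.initialize()` would normally create and stamp via `init_autocast_params()`:

    import torch
    from deepspeed.runtime.torch_autocast import (_AutocastState, get_autocast_dtype,
                                                  is_autocast_initialized)

    class FakeEngine:
        pass

    bf16_engine, plain_engine = FakeEngine(), FakeEngine()
    bf16_engine._autocast_state = _AutocastState(initialized=True, dtype=torch.bfloat16)

    assert is_autocast_initialized(bf16_engine)
    assert get_autocast_dtype(bf16_engine) is torch.bfloat16
    # The second engine sees the default state rather than inheriting the first
    # engine's configuration, which was the hazard with the module-level singleton.
    assert not is_autocast_initialized(plain_engine)
    assert get_autocast_dtype(plain_engine) is None
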