diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index 213b5c659499..22f8cedc513a 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -10,7 +10,7 @@
 from deepspeed.utils import logger
 from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
 from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage
-from deepspeed.runtime.torch_autocast import get_comm_dtype, is_autocast_initialized
+from deepspeed.runtime.torch_autocast import get_comm_dtype, has_comm_dtype
 from deepspeed.runtime.utils import maybe_loss_for_backward
 
 
@@ -354,7 +354,10 @@ def report_ipg_memory_usage(self, tag, param_elems, dtype=None):
         )
 
     def get_param_comm_dtype(self, param):
-        if is_autocast_initialized():
+        # Use the per-parameter comm_dtype attribute set by init_autocast_params().
+        # Each engine stamps its own parameters, so multiple engines with different
+        # autocast configs are naturally isolated without a shared global state.
+        if has_comm_dtype(param):
             return get_comm_dtype(param)
         else:
             return self.communication_data_type
diff --git a/deepspeed/runtime/torch_autocast.py b/deepspeed/runtime/torch_autocast.py
index 299693fdaab5..7a098a428d5c 100644
--- a/deepspeed/runtime/torch_autocast.py
+++ b/deepspeed/runtime/torch_autocast.py
@@ -3,7 +3,8 @@
 
 # DeepSpeed Team
 
-from typing import Iterable, Set, List, Union
+from dataclasses import dataclass, field
+from typing import Iterable, Optional, Set, List, Union
 import importlib
 from contextlib import contextmanager
 
@@ -22,9 +23,17 @@
 PARAM_COMM_DTYPE_ATTR_NAME = "comm_dtype"
 _WARNED_NESTED_AUTOCAST = False
 
-# TODO: Avoid using global variables
-TORCH_AUTOCAST_INITIALIZED = False
-TORCH_AUTOCAST_DTYPE = None
+
+@dataclass
+class _AutocastState:
+    """Holds torch-autocast initialization state for one DeepSpeed engine instance.
+
+    Storing this object on the engine (``engine._autocast_state``) rather than as
+    a module-level singleton allows multiple engine instances to carry independent
+    autocast configurations without interfering with each other.
+    """
+    initialized: bool = False
+    dtype: Optional[torch.dtype] = field(default=None)
 
 
 def _validate_auto_cast_settings(engine):
@@ -56,22 +65,26 @@ def init_autocast_params(engine, dtype: torch.dtype,
             for p in module.parameters(recurse=False):
                 setattr(p, PARAM_COMM_DTYPE_ATTR_NAME, dtype)
 
-    global TORCH_AUTOCAST_INITIALIZED
-    TORCH_AUTOCAST_INITIALIZED = True
-    global TORCH_AUTOCAST_DTYPE
-    TORCH_AUTOCAST_DTYPE = dtype
+    engine._autocast_state = _AutocastState(initialized=True, dtype=dtype)
 
 
-def is_autocast_initialized() -> bool:
-    return TORCH_AUTOCAST_INITIALIZED
+def is_autocast_initialized(engine) -> bool:
+    """Return True if torch autocast was initialized for *this* engine instance.
+
+    Accepts the engine as an argument so that multiple DeepSpeed engines can
+    carry independent ``_autocast_state`` objects without sharing a
+    module-level singleton.
+    """
+    return getattr(engine, '_autocast_state', _AutocastState()).initialized
 
 
 def get_default_autocast_lower_precision_modules() -> List[str]:
     return [f"{cls.__module__}.{cls.__name__}" for cls in LOWER_PRECISION_SAFE_MODULES]
 
 
-def get_autocast_dtype() -> torch.dtype:
-    return TORCH_AUTOCAST_DTYPE
+def get_autocast_dtype(engine) -> torch.dtype:
+    """Return the autocast dtype configured for *this* engine instance."""
+    return getattr(engine, '_autocast_state', _AutocastState()).dtype
 
 
 def has_comm_dtype(param: torch.nn.Parameter) -> bool:
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 8f28ee4f8685..96424ffd761a 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -19,7 +19,7 @@
 from deepspeed.utils import logger
 from deepspeed.utils.torch import register_grad_hook, required_torch_version
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
-from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
+from deepspeed.runtime.torch_autocast import get_all_comm_dtypes, has_comm_dtype, sort_dtypes
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
 from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward
 from deepspeed.runtime.zero.partition_parameters import *
@@ -435,9 +435,10 @@ def _enforce_optimizer_offload():
 
         self.is_param_in_current_partition = {}
         self.torch_autocast_gradscaler = None
-        if is_autocast_initialized():
-            comm_dtypes = get_all_comm_dtypes([p for params in self.fp16_groups for p in params])
-            if get_autocast_dtype() == torch.float16:
+        all_params = [p for params in self.fp16_groups for p in params]
+        if any(has_comm_dtype(p) for p in all_params):
+            comm_dtypes = get_all_comm_dtypes(all_params)
+            if torch.float16 in comm_dtypes:
                 self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name())
         else:
             comm_dtypes = {self.communication_data_type}
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index f3a0352bebfa..2a7bd684706e 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -19,7 +19,7 @@
 from deepspeed.runtime.zero.offload_states import offload_optimizer_states, reload_optimizer_states
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
-from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
+from deepspeed.runtime.torch_autocast import get_all_comm_dtypes, has_comm_dtype, sort_dtypes
 from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
                                      align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
                                      count_used_parameters_in_backward)
@@ -518,9 +518,10 @@ def _enforce_cpu_offload():
 
         self.is_param_in_current_partition = {}
         self.torch_autocast_gradscaler = None
-        if is_autocast_initialized():
-            comm_dtypes = get_all_comm_dtypes([p for params in self.bit16_groups for p in params])
-            if get_autocast_dtype() == torch.float16:
+        all_params = [p for params in self.bit16_groups for p in params]
+        if any(has_comm_dtype(p) for p in all_params):
+            comm_dtypes = get_all_comm_dtypes(all_params)
+            if torch.float16 in comm_dtypes:
                 self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name())
         else:
             comm_dtypes = {self.communication_data_type}
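
To make the per-engine state pattern above concrete, here is a minimal standalone sketch. The `SimpleNamespace` engines and the re-declared helpers are illustrative stand-ins, not real DeepSpeed engine objects; only the `getattr(..., _AutocastState())` fallback mirrors what the diff actually does, so an engine that never went through `init_autocast_params()` simply reports an uninitialized state.

```python
# Standalone sketch of the per-engine autocast state; "engines" here are stand-ins.
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Optional

import torch


@dataclass
class _AutocastState:
    initialized: bool = False
    dtype: Optional[torch.dtype] = field(default=None)


def is_autocast_initialized(engine) -> bool:
    # A missing attribute falls back to a default (uninitialized) state,
    # so callers never need hasattr() checks.
    return getattr(engine, '_autocast_state', _AutocastState()).initialized


def get_autocast_dtype(engine) -> Optional[torch.dtype]:
    return getattr(engine, '_autocast_state', _AutocastState()).dtype


engine_a = SimpleNamespace()  # stand-in for one DeepSpeed engine
engine_b = SimpleNamespace()  # stand-in for a second, independent engine

# Only engine_a is autocast-initialized; engine_b keeps the default state.
engine_a._autocast_state = _AutocastState(initialized=True, dtype=torch.bfloat16)

assert is_autocast_initialized(engine_a) and get_autocast_dtype(engine_a) is torch.bfloat16
assert not is_autocast_initialized(engine_b) and get_autocast_dtype(engine_b) is None
```

Because the fallback constructs a default `_AutocastState`, two engines in the same process can hold different autocast configurations without touching any module-level globals, which is the point of replacing `TORCH_AUTOCAST_INITIALIZED`/`TORCH_AUTOCAST_DTYPE`.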