diff --git a/AGENTS.md b/AGENTS.md index 70ed661a3c35..15a2d15e77a5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,7 +7,7 @@ - All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`. - Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`). -- Always verify changed files pass pre-commit checks before committing. Config: `.pre-commit-config.yaml`. +- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files `. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`. - `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead. - New files require license header: ``` diff --git a/CLAUDE.md b/CLAUDE.md index 70ed661a3c35..15a2d15e77a5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -7,7 +7,7 @@ - All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`. - Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`). -- Always verify changed files pass pre-commit checks before committing. Config: `.pre-commit-config.yaml`. +- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files `. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`. - `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead. - New files require license header: ``` diff --git a/README.md b/README.md index b7d4eaffda0e..73b2b701a74b 100755 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ ## Latest News +* [2026/03] [Our SuperOffload work received an Honorable Mention for the ASPLOS 2026 Best Paper Award](https://dl.acm.org/doi/10.1145/3760250.3762217) + * [2025/12] [DeepSpeed Core API updates: PyTorch-style backward and low-precision master states](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/core_api_update/README.md) * [2025/11] [DeepSpeed ZeRO++ powers large-scale distillation training of LLMs for Recommendation Systems at LinkedIn](https://aclanthology.org/2025.emnlp-industry.119/) diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index d0974497b4c3..a71b815fa20c 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -236,14 +236,12 @@ def rollback_subgroup(self, sub_group_id: int, closure=None): f"CPUAdam param is on {param.device} and must be 'cpu', " f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.") - # Decrement step count - subgroup_state['step'] -= 1 - - # Extract hyperparameters beta1, beta2 = group['betas'] self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], group['bias_correction'], param.data, param.grad.data, subgroup_state['exp_avg'], subgroup_state['exp_avg_sq']) + + subgroup_state['step'] -= 1 return loss diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py index 633997b61a5f..f8d11ce6fdf3 100644 --- a/deepspeed/ops/transformer/inference/triton/matmul_ext.py +++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py @@ -25,11 +25,19 @@ def is_nfs_path(path): # Normalize the path to get the absolute path path = os.path.abspath(path) + # Walk up to the nearest existing ancestor so 'df' does not fail + # when the target 
directory has not been created yet (see #7642). + while not os.path.exists(path): + parent = os.path.dirname(path) + if parent == path: + break + path = parent + # Use the 'df' command to find the file system type for the given path try: - output = subprocess.check_output(['df', '-T', path], encoding='utf-8') - except subprocess.CalledProcessError: - return False # Command failed + output = subprocess.check_output(['df', '-T', path], encoding='utf-8', stderr=subprocess.DEVNULL) + except (subprocess.CalledProcessError, FileNotFoundError): + return False # Command failed or 'df' not available # Process the output of 'df -T' to check for 'nfs' in the filesystem type column lines = output.strip().split('\n') diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index bfbb19e68696..aa9deaf81ad6 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1101,6 +1101,9 @@ def zero_legacy_stage1(self): def zero_ignore_unused_parameters(self): return self._config.zero_config.ignore_unused_parameters + def zero_save_muon_momentum_buffer_in_memory(self): + return self._config.zero_config.save_muon_momentum_buffer_in_memory + def tensor_parallel_config(self): return self._config.tensor_parallel_config @@ -1733,7 +1736,6 @@ def _configure_basic_optimizer(self, model_parameters): optimizer = MuSGD(model_parameters, **optimizer_parameters) elif self.optimizer_name() == MUON_OPTIMIZER: zero_stage = self.zero_optimization_stage() - assert zero_stage <= ZeroStageEnum.gradients, "Muon optimizer is not yet compatible with ZeRO Stage 3" if not all([hasattr(p, 'use_muon') for p in model_parameters]): msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \ "please set by `param.use_muon = True / False` for all params" @@ -2045,6 +2047,7 @@ def _configure_zero_optimizer(self, optimizer): log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), enable_sanity_checks=self.is_sanity_checks_enabled(), cpuadam_cores_perc=self.cpuadam_cores_perc(), + save_muon_momentum_buffer_in_memory=self.zero_save_muon_momentum_buffer_in_memory(), ) else: diff --git a/deepspeed/runtime/superoffload/superoffload_stage3.py b/deepspeed/runtime/superoffload/superoffload_stage3.py index 90b9bf297358..7c496a3dda37 100644 --- a/deepspeed/runtime/superoffload/superoffload_stage3.py +++ b/deepspeed/runtime/superoffload/superoffload_stage3.py @@ -5,7 +5,6 @@ import time import torch -from collections import deque from typing import List from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, TaskKeys, ResultKeys, EventTypes @@ -18,6 +17,13 @@ OPTIMIZER_STEP_TIMER = 'optimizer_step' +def _validate_superoffload_accelerator(): + """Validate that the current accelerator is compatible with SuperOffload.""" + accelerator = get_accelerator() + assert accelerator.device_name() == 'cuda', ( + f"SuperOffload only supports NVIDIA CUDA GPUs, but found accelerator '{accelerator.device_name()}'.") + + class SuperOffloadOptimizer_Stage3(DeepSpeedZeroOptimizer_Stage3): def __init__( @@ -29,24 +35,26 @@ def __init__( ds_config, **kwargs, ): + _validate_superoffload_accelerator() self.sub_group_to_param_num = {} - self.params_in_ipg_bucket_buffer = deque() - self._cur_bucket_index = -1 + self.sub_group_grad_partition_counts = {} self.async_cpuadam_num = 0 self.max_grad_numel = 0 super().__init__(module, init_optimizer, param_names, timers, ds_config, **kwargs) - optimizer_config = { - "lr": self.optimizer.param_groups[0]["lr"], - 
"betas": self.optimizer.param_groups[0]["betas"], - "eps": self.optimizer.param_groups[0]["eps"], - "weight_decay": self.optimizer.param_groups[0]["weight_decay"], - "amsgrad": self.optimizer.param_groups[0]["amsgrad"] - } + optimizer_configs = [] + for pg in self.optimizer.param_groups: + optimizer_configs.append({ + "lr": pg["lr"], + "betas": pg["betas"], + "eps": pg["eps"], + "weight_decay": pg["weight_decay"], + "amsgrad": pg["amsgrad"], + }) cpuadam_cores_perc = kwargs.get("cpuadam_cores_perc", 0.8) - self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_config, + self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_configs, cpuadam_cores_perc=cpuadam_cores_perc, max_grad_numel=self.max_grad_numel) @@ -56,6 +64,9 @@ def _create_fp16_sub_groups(self, params_group): sub_group_size = self.sub_group_size if sub_group_size is None or sub_group_size >= params_group_numel: + global_idx = len(self.sub_group_to_param_num) + self.sub_group_to_param_num[global_idx] = len(params_group) + self.max_grad_numel = max(self.max_grad_numel, params_group_numel) return [params_group] sub_groups = [] @@ -69,7 +80,8 @@ def _create_fp16_sub_groups(self, params_group): if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]): self.max_grad_numel = max(self.max_grad_numel, local_sub_group_size) sub_groups.append(sub_group) - self.sub_group_to_param_num[len(sub_groups) - 1] = len(sub_group) + global_idx = len(self.sub_group_to_param_num) + self.sub_group_to_param_num[global_idx] = len(sub_group) sub_group = [] local_sub_group_size = 0 @@ -93,43 +105,16 @@ def step_with_gradscaler(optimizer): step_with_gradscaler(self.backup_optimizer) self.backup_optimizer.param_groups[param_group_id]['params'] = [] - def reduce_independent_p_g_buckets_and_remove_grads(self, param): - comm_dtype = self.get_param_comm_dtype(param) - bucket = self.ipg_buckets[comm_dtype] - i, _, _ = self.grad_position[self.get_param_id(param)] - - if len(bucket.params) == 0: - self._cur_bucket_index = i - if getattr(param, "ds_grad_is_ready", True): - self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param) - - # If this is a single-parameter sub-group, reduce immediately - if self.sub_group_to_param_num[self._cur_bucket_index] == 1: - self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype) - - elif i != self._cur_bucket_index: - # Parameter belongs to different sub-group, buffer it - self.params_in_ipg_bucket_buffer.append(param) - else: - # Parameter belongs to current bucket - if getattr(param, "ds_grad_is_ready", True): - self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param) - - # Check if bucket is complete - if self.sub_group_to_param_num[self._cur_bucket_index] == len(bucket.params): - self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype) - - # Process buffered parameters - while self.params_in_ipg_bucket_buffer: - buffered_param = self.params_in_ipg_bucket_buffer.popleft() - ci, _, _ = self.grad_position[self.get_param_id(buffered_param)] - self._cur_bucket_index = ci - if getattr(buffered_param, "ds_grad_is_ready", True): - self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(buffered_param) + @instrument_w_nvtx + def independent_gradient_partition_epilogue(self): + super().independent_gradient_partition_epilogue() + self.sub_group_grad_partition_counts.clear() @instrument_w_nvtx def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): if 
self.subgroup_to_device[sub_group_id] == 'cpu': + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) self._unflatten_partitioned_parameters(sub_group_id) return @@ -147,62 +132,61 @@ def _reassign_or_swap_out_partitioned_parameters_async(self, sub_group_id, updat @instrument_w_nvtx def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: - # print("[DEBUG] partition_grads called") - buffers = [] - device_buffers = {} - buffer_numel_min = {} - buffer_numel_max = {} + completed_sub_groups = [] for param, grad_partition in zip(params_to_release, grad_partitions): i, dest_offset, _ = self.grad_position[self.get_param_id(param)] - if self.is_gradient_accumulation_boundary: - self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_partition) - - buffer_numel = grad_partition.numel() - buffers.append(grad_partition) - - if i not in device_buffers: - device_buffers[i] = [] - device_buffers[i].append(grad_partition) - - if i not in buffer_numel_min: - buffer_numel_min[i] = dest_offset - buffer_numel_max[i] = dest_offset + buffer_numel + # Accumulate gradient into the grad_buffer, mirroring base class logic + grad_buffer = self._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id].narrow( + 0, 0, grad_partition.numel()) + if self.micro_step_id == 0: + grad_buffer.copy_(grad_partition, non_blocking=True) + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif get_accelerator().on_accelerator(grad_buffer): + grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(grad_buffer.shape)) else: - buffer_numel_min[i] = min(buffer_numel_min[i], dest_offset) - buffer_numel_max[i] = max(buffer_numel_max[i], dest_offset + buffer_numel) + cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + cuda_grad_buffer.add_(grad_partition.to(self.gradient_accumulation_dtype).view(cuda_grad_buffer.shape)) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + grad_buffer = cuda_grad_buffer + + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_buffer) - if self.is_gradient_accumulation_boundary: - for i in buffer_numel_min.keys(): fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow( - 0, buffer_numel_min[i], buffer_numel_max[i] - buffer_numel_min[i]) - concatenated_buffer = torch.cat(device_buffers[i], dim=0).to(dtype=self.master_weights_and_grads_dtype) + 0, dest_offset, grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer.to(dtype=self.master_weights_and_grads_dtype), non_blocking=True) - if self.subgroup_to_device[i] == 'cpu': - # Trigger asynchronous CPU optimization + self.sub_group_grad_partition_counts[i] = self.sub_group_grad_partition_counts.get(i, 0) + 1 + if self.sub_group_grad_partition_counts[i] == self.sub_group_to_param_num[i]: + completed_sub_groups.append(i) + + if self.is_gradient_accumulation_boundary and completed_sub_groups: + get_accelerator().current_stream().synchronize() + for i in completed_sub_groups: + if self.subgroup_to_device[i] == 'cpu' and not self.clip_grad: param_group_id = self.sub_group_to_group_id[i] fp32_param = self.fp32_partitioned_groups_flat[i] + current_lr = self.optimizer.param_groups[param_group_id]['lr'] - self.superoffload_cpu_optimizer.async_step(param_group_id, i, fp32_param.data, - concatenated_buffer.data) + 
self.superoffload_cpu_optimizer.async_step(param_group_id, + i, + fp32_param.data, + fp32_param.grad.data, + lr=current_lr) self.async_cpuadam_num += 1 - # Check for completed async operations result = self.superoffload_cpu_optimizer.get_result() if result is not None: self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID], result[ResultKeys.UPDATED_PARAM]) self.async_cpuadam_num -= 1 - fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True) - else: - fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True) - - # Clean up parameter gradients for param in params_to_release: if not get_accelerator().is_synchronized_device(): - param.grad.record_stream(get_accelerator().current_stream()) + if param.grad is not None: + param.grad.record_stream(get_accelerator().current_stream()) param.grad = None @instrument_w_nvtx @@ -210,14 +194,14 @@ def step(self, closure=None): """ Not supporting closure. """ - # Wait for any pending asynchronous CPU optimizer operations self._wait_for_async_operations() self._pre_step() self._partition_all_parameters() if self._overflow_check_and_loss_scale_update(): - self._handle_overflow_rollback() + if not self.clip_grad: + self._handle_overflow_rollback() return norm_groups = self._get_norm_groups() @@ -228,28 +212,45 @@ def step(self, closure=None): timer_names.add(OPTIMIZER_STEP_TIMER) self.timers(OPTIMIZER_STEP_TIMER).start() - if self.check_clip_grads(scaled_global_grad_norm): - self._handle_gradient_clipping(scaled_global_grad_norm) + if self.clip_grad: + self._step_with_clipping(scaled_global_grad_norm, timer_names) + else: + self._step_without_clipping(scaled_global_grad_norm, timer_names) + + self.timers(OPTIMIZER_STEP_TIMER).stop() + self._post_step(timer_names) + def _step_without_clipping(self, scaled_global_grad_norm, timer_names): + """Fast path: async CPU steps already completed during backward.""" for sub_group_id, group in enumerate(self.fp16_groups): - # Prepare optimizer states, gradients and fp32 parameters for update self._prepare_sub_group(sub_group_id, timer_names) + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + self._optimizer_step(sub_group_id) + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + self._release_sub_group(sub_group_id, timer_names) - # Scale the fp32 gradients + def _step_with_clipping(self, scaled_global_grad_norm, timer_names): + """Clipping path: no async steps were done during backward, + so we unscale+clip first, then step all sub-groups.""" + for sub_group_id, group in enumerate(self.fp16_groups): + self._prepare_sub_group(sub_group_id, timer_names) self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) - # Apply the optimizer step on the sub group and copy fp32 parameters to fp16 - self._optimizer_step(sub_group_id) + if self.subgroup_to_device[sub_group_id] == 'cpu': + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + current_lr = self.optimizer.param_groups[param_group_id]['lr'] + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + lr=current_lr) + else: + self._optimizer_step(sub_group_id) - # Put fp16 parameters in appropriate location self._reassign_or_swap_out_partitioned_parameters(sub_group_id) - - # Release memory or swap out optimizer states of fp32 parameters self._release_sub_group(sub_group_id, timer_names) - self.timers(OPTIMIZER_STEP_TIMER).stop() - self._post_step(timer_names) - def 
_wait_for_async_operations(self, timeout_seconds=60): """Wait for all pending asynchronous CPU optimizer operations to complete with timeout error. @@ -316,13 +317,15 @@ def _sync_cpu_optimizer_step(self, fp32_param_data, fp32_grad_data, rollback: bool = False, + lr: float = None, timeout_seconds: int = 60): event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP self.superoffload_cpu_optimizer.async_step(param_group_id, sub_group_id, fp32_param_data, fp32_grad_data, - rollback=rollback) + rollback=rollback, + lr=lr) # Wait for completion self._wait_for_single_async_result(event_type, timeout_seconds) @@ -357,11 +360,13 @@ def _handle_gradient_clipping(self, scaled_global_grad_norm): # Clip gradients and re-optimize self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + current_lr = self.optimizer.param_groups[param_group_id]['lr'] self._sync_cpu_optimizer_step(param_group_id, sub_group_id, fp32_param.data, fp32_param.grad.data, - rollback=False) + rollback=False, + lr=current_lr) @instrument_w_nvtx def check_clip_grads(self, total_norm): diff --git a/deepspeed/runtime/superoffload/superoffload_utils.py b/deepspeed/runtime/superoffload/superoffload_utils.py index e023730bd43e..c8a734b7c48c 100644 --- a/deepspeed/runtime/superoffload/superoffload_utils.py +++ b/deepspeed/runtime/superoffload/superoffload_utils.py @@ -22,6 +22,7 @@ class TaskKeys: PARAM_GROUP_ID = "param_group_id" SUB_GROUP_ID = "sub_group_id" ROLLBACK = "rollback" + LR = "lr" class ResultKeys: @@ -48,17 +49,32 @@ def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp. lr, betas, eps, weight_decay, and amsgrad parameters max_grad_numel: Maximum number of elements expected in gradient tensors """ - # Initialize dummy parameter for optimizer creation cpu_tensor = torch.randn(1, device="cpu") cpu_param = torch.nn.Parameter(cpu_tensor) try: + if isinstance(optimizer_config, list): + pg_configs = optimizer_config + else: + pg_configs = [optimizer_config] + + first_cfg = pg_configs[0] optimizer = DeepSpeedCPUAdam([cpu_param], - lr=optimizer_config["lr"], - betas=optimizer_config["betas"], - eps=optimizer_config["eps"], - weight_decay=optimizer_config["weight_decay"], - amsgrad=optimizer_config["amsgrad"]) + lr=first_cfg["lr"], + betas=first_cfg["betas"], + eps=first_cfg["eps"], + weight_decay=first_cfg["weight_decay"], + amsgrad=first_cfg["amsgrad"]) + for cfg in pg_configs[1:]: + dummy = torch.nn.Parameter(torch.randn(1, device="cpu")) + optimizer.add_param_group({ + "params": [dummy], + "lr": cfg["lr"], + "betas": cfg["betas"], + "eps": cfg["eps"], + "weight_decay": cfg["weight_decay"], + "amsgrad": cfg["amsgrad"], + }) except KeyError as e: error_msg = f"Missing required optimizer config key: {e}" logger.error(error_msg) @@ -81,6 +97,7 @@ def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp. param_group_id = task[TaskKeys.PARAM_GROUP_ID] sub_group_id = task[TaskKeys.SUB_GROUP_ID] rollback = task.get(TaskKeys.ROLLBACK, False) + task_lr = task.get(TaskKeys.LR, None) logger.debug(f"Processing param_group_id: {param_group_id}, sub_group_id: {sub_group_id}") @@ -88,6 +105,9 @@ def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp. 
del task[TaskKeys.PARAM_GRAD] task.clear() + if task_lr is not None: + optimizer.param_groups[param_group_id]['lr'] = task_lr + grad_numel = param_grad.numel() if grad_numel > max_grad_numel: error_msg = ( @@ -97,7 +117,7 @@ def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp. break param_grad_cpu = pinned_grad_buffer[:grad_numel].view_as(param_grad) - param_grad_cpu.copy_(param_grad, non_blocking=True) + param_grad_cpu.copy_(param_grad, non_blocking=False) fp32_param = torch.nn.Parameter(param_data) fp32_param.grad = param_grad_cpu @@ -202,20 +222,24 @@ def async_step(self, sub_group_id: int, fp32_param: torch.Tensor, fp32_grad: torch.Tensor, - rollback: bool = False) -> None: + rollback: bool = False, + lr: float = None) -> None: """ Queue parameter for optimization in the worker process. """ if not self.cpuadam_process.is_alive(): raise RuntimeError("Worker process is not alive") - self.param_queue.put({ + task = { TaskKeys.PARAM_DATA: fp32_param, TaskKeys.PARAM_GRAD: fp32_grad, TaskKeys.PARAM_GROUP_ID: param_group_id, TaskKeys.SUB_GROUP_ID: sub_group_id, TaskKeys.ROLLBACK: rollback, - }) + } + if lr is not None: + task[TaskKeys.LR] = lr + self.param_queue.put(task) def get_result(self, expected_event_type: str = None) -> Optional[Dict[str, Any]]: """ diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index def8d1db5653..79fbcb97a188 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -366,6 +366,13 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): Enable internal sanity checks, which could be useful for debugging """ + save_muon_momentum_buffer_in_memory: bool = False + """ + When using the Muon optimizer with ZeRO Stage 3, keeps the Muon momentum + buffer in GPU/CPU memory instead of swapping to NVMe with other optimizer + states. Only relevant when using NVMe offloading. + """ + leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig) """ Configuration for modules that should be treated as ZeRO3 leaf modules. 
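Taken together with the docs change further down in this patch, the new `save_muon_momentum_buffer_in_memory` field is surfaced through the ordinary `zero_optimization` section of a DeepSpeed config. Below is a minimal sketch of a ZeRO Stage 3 + Muon configuration that keeps the momentum buffer resident, assuming NVMe offloading of optimizer states; the batch size, learning rate, and `nvme_path` are illustrative placeholders rather than values taken from this patch.

```python
# Sketch only: a DeepSpeed config dict exercising the new flag.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Muon",
        "params": {
            "lr": 0.001,
            "momentum": 0.9,
            "weight_decay": 0.0,
        },
    },
    "zero_optimization": {
        "stage": 3,
        # Muon with ZeRO-3 requires the non-reduce-scatter gradient path
        # (stage3.py below raises if reduce_scatter is left enabled).
        "reduce_scatter": False,
        # Keep the Muon momentum buffer in GPU/CPU memory rather than
        # swapping it to NVMe with the other optimizer states.
        "save_muon_momentum_buffer_in_memory": True,
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": "/local_nvme",
            "pin_memory": True,
        },
    },
}
```

With the default of `false`, the momentum buffer is instead created around optimizer-state swap-in and moved to NVMe alongside the other optimizer states, as handled by the `stage3.py` changes below.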
diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index db95a5ac789c..7421fd10c5ef 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -16,12 +16,14 @@ #when implemented outside of torch.autograd.Function import math +import functools import torch from torch import Tensor from torch.nn.parameter import Parameter from torch.nn import init from torch.nn.modules.module import Module +from deepspeed.runtime.utils import noop_decorator from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator @@ -31,6 +33,49 @@ def print_rank_0(message, debug=False, force=False): print(message) +def _get_legacy_autocast_decorators(device_type): + legacy_amp = getattr(getattr(torch, device_type, None), 'amp', None) + custom_fwd = getattr(legacy_amp, 'custom_fwd', None) + custom_bwd = getattr(legacy_amp, 'custom_bwd', None) + if custom_fwd is not None and custom_bwd is not None: + return custom_fwd, custom_bwd + return noop_decorator, noop_decorator + + +def _get_autocast_decorators(): + amp = getattr(torch, 'amp', None) + custom_fwd = getattr(amp, 'custom_fwd', None) + custom_bwd = getattr(amp, 'custom_bwd', None) + if custom_fwd is not None and custom_bwd is not None: + device_type = get_accelerator().device_name() + return functools.partial(custom_fwd, device_type=device_type), functools.partial(custom_bwd, + device_type=device_type) + return _get_legacy_autocast_decorators(get_accelerator().device_name()) + + +autocast_custom_fwd, autocast_custom_bwd = _get_autocast_decorators() + + +def _is_autocast_enabled(device_type): + try: + return torch.is_autocast_enabled(device_type) + except TypeError: + legacy_getter = getattr(torch, f'is_autocast_{device_type}_enabled', None) + if legacy_getter is not None: + return legacy_getter() + return torch.is_autocast_enabled() + + +def _get_autocast_dtype(device_type): + try: + return torch.get_autocast_dtype(device_type) + except TypeError: + legacy_getter = getattr(torch, f'get_autocast_{device_type}_dtype', None) + if legacy_getter is not None: + return legacy_getter() + return None + + class LinearFunctionForZeroStage3(torch.autograd.Function): @staticmethod @@ -51,8 +96,8 @@ def forward(input, weight, bias=None): @staticmethod def setup_context(ctx, inputs, output): device_type = get_accelerator().device_name() - ctx._dtype = torch.get_autocast_dtype(device_type) - ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type) + ctx._dtype = _get_autocast_dtype(device_type) + ctx._fwd_used_autocast = _is_autocast_enabled(device_type) input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None ctx.save_for_backward(input, weight, bias) @@ -63,7 +108,7 @@ def backward(ctx, grad_output): # autocast state as forward — including explicitly disabling autocast # when forward did not use it, to guard against outer autocast regions. 
device_type = get_accelerator().device_name() - with torch.amp.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype): + with torch.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype): input, weight, bias = ctx.saved_tensors grad_input = grad_weight = grad_bias = None diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 25f5e27e30f8..c4f19f43de4f 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -36,6 +36,8 @@ from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.muon.original_muon import muon_update +from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam # Toggle this to true to enable correctness test # with gradient partitioning and without @@ -190,6 +192,7 @@ def __init__( log_trace_cache_warnings=False, enable_sanity_checks=False, cpuadam_cores_perc=0.8, + save_muon_momentum_buffer_in_memory=False, ): see_memory_usage("Stage 3 initialize beginning", force=False) @@ -329,12 +332,15 @@ def _enforce_optimizer_offload(): self.all2all_process_group = all2all_process_group self.reduce_scatter = reduce_scatter - + self.use_muon = isinstance(self.optimizer, MuonWithAuxAdam) + self.save_muon_momentum_buffer_in_memory = save_muon_momentum_buffer_in_memory + if self.use_muon and self.reduce_scatter: + raise ValueError("Muon and reduce scatter cannot be used together") + if self.use_muon and self.all2all_process_group is not None: + raise ValueError("Muon and all2all process group cannot be used together") self.dp_process_group = self.parameter_offload.dp_process_group self.sequence_parallel_size = groups._get_sequence_parallel_world_size() - self.all2all_process_group = all2all_process_group - self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights self.partition_count = dist.get_world_size(group=self.dp_process_group) @@ -385,6 +391,8 @@ def _enforce_optimizer_offload(): #a single 32-bit partition of the parallel partitioned parameters #that this process will update self.fp32_partitioned_groups_flat = [] + if self.use_muon and self.save_muon_momentum_buffer_in_memory: + self.muon_momentum_buffer_partitioned_groups_flat = {} self.next_swappable_fp32_partitioned_groups = [] # number of elements per partition in each group @@ -780,6 +788,19 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): param_groups: List[List[Parameter]] = tuple( self._create_fp16_sub_groups(param_group["params"]) for param_group in fp16_param_groups) + if self.use_muon: + self.sub_groups_using_muon = [] + self.muon_beta = None + for idx, param_group in enumerate(fp16_param_groups): + if getattr(param_group['params'][0], 'use_muon', False): + self.sub_groups_using_muon.extend([True] * len(param_groups[idx])) + group_beta = param_group['momentum'] + if self.muon_beta is not None and self.muon_beta != group_beta: + raise ValueError(f"All Muon parameter groups must have the same momentum (beta). 
" + f"Found {self.muon_beta} and {group_beta}.") + self.muon_beta = group_beta + else: + self.sub_groups_using_muon.extend([False] * len(param_groups[idx])) # bookkeeping related to param groups for param_group_idx, param_group in enumerate(param_groups): for sub_group in param_group: @@ -907,6 +928,20 @@ def _get_sub_group_partitions(self, sub_group_id): return sub_group_partitions + def _create_momentum_buffer(self, num_elements, i, ds_id): + if self.use_muon and self.sub_groups_using_muon[i]: + unpinned_fp32_buffer_momentum = torch.zeros(num_elements, + device=self.device, + dtype=self.communication_data_type) + unpinned_fp32_buffer_momentum.requires_grad = False + if self.fp32_partitioned_groups_flat[i] not in self.optimizer.state: + self.optimizer.state[self.fp32_partitioned_groups_flat[i]] = {} + self.optimizer.state[ + self.fp32_partitioned_groups_flat[i]]["momentum_buffer"] = unpinned_fp32_buffer_momentum + if self.save_muon_momentum_buffer_in_memory: + self.muon_momentum_buffer_partitioned_groups_flat[i] = unpinned_fp32_buffer_momentum + self.muon_momentum_buffer_partitioned_groups_flat[i].ds_id = ds_id + def _create_fp32_partitions(self): cpu_memory_usage = 0 cpu_memory_sub_groups = 0 @@ -948,6 +983,9 @@ def _create_fp32_partitions(self): self.fp32_partitioned_groups_flat[i].ds_id = ds_id nvme_memory_usage += (fp32_element_size * num_elements) num_swappable_partitions += 1 + if not (self.use_muon and self.sub_groups_using_muon[i] + and not self.save_muon_momentum_buffer_in_memory): + self._create_momentum_buffer(num_elements, i, ds_id) if self.params_in_nvme_and_cpu and tensor is None: num_swap_from_nvme_partitions += 1 @@ -979,20 +1017,24 @@ def _create_fp32_partitions(self): dtype=self.master_weights_and_grads_dtype) self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) + self._create_momentum_buffer(num_elements, i, ds_id) elif self.offload_optimizer: converted = self.fp16_partitioned_groups_flat[i].to(self.subgroup_to_device[i], dtype=self.master_weights_and_grads_dtype) self.fp32_partitioned_groups_flat.append(converted.clone().detach()) + self._create_momentum_buffer(num_elements, i, ds_id) elif self.fp16_partitioned_groups_flat[i].dtype == self.master_weights_and_grads_dtype and \ self.fp16_partitioned_groups_flat[i].device == self.device: # When torch autocast is enabled, weights in the provided model (and thus groups in the so-called # "fp16" partitioned groups) are already in and updated using fp32. In such cases we don't need # another copy of the weights. self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i]) + self._create_momentum_buffer(num_elements, i, ds_id) else: converted = self.fp16_partitioned_groups_flat[i].to(self.device, dtype=self.master_weights_and_grads_dtype) self.fp32_partitioned_groups_flat.append(converted.clone().detach()) + self._create_momentum_buffer(num_elements, i, ds_id) self.fp32_partitioned_groups_flat[i].ds_id = ds_id self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it @@ -1151,6 +1193,10 @@ def initialize_optimizer_states(self): if swappable_optimizer_subgroup: self._optimizer_states_and_gradient_swap_in(i, timer_names) + if self.use_muon and self.sub_groups_using_muon[i] and not self.save_muon_momentum_buffer_in_memory: + # Create momentum buffer after swap-in so swap files can be created on swap-out. 
+ if "momentum_buffer" not in self.optimizer.state.get(self.fp32_partitioned_groups_flat[i], {}): + self._create_momentum_buffer(num_elements, i, self.fp32_partitioned_groups_flat[i].ds_id) if self.offload_optimizer and not swappable_optimizer_subgroup: subgroup_gradient_buffer = torch.zeros(num_elements, dtype=gradient_dtype, device=self.device) @@ -1430,6 +1476,122 @@ def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype) event.record() self.param_reduce_events.append(event) + def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, buffer_to_reduce: Tensor): + """ + Update the momentum buffer of the parameters using muon. + Args: + communication_data_type: torch.dtype + buffer_to_reduce: Tensor + Returns: + None + """ + if not self.use_muon: + return + + params_by_group = {} + params_size_offset = 0 + for param in self.ipg_buckets[communication_data_type].params: + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + if self.sub_groups_using_muon[i]: + # copy the gradients back to the params in the ipg bucket for the muon update + param.grad.data.copy_(buffer_to_reduce.narrow(0, params_size_offset, + param.grad.numel()).view_as(param.grad), + non_blocking=False) + if i not in params_by_group: + params_by_group[i] = [] + params_by_group[i].append((param, dest_offset, params_size_offset)) + params_size_offset += param.grad.numel() + + # process muon updates per subgroup to avoid holding all parameters and states at once + for i, group_items in params_by_group.items(): + params = [param for param, _, _ in group_items] + if not params: + continue + + momentum_buffer = [] + if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory: + # swap-in once, keep resident through update + writeback + self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i]) + if "momentum_buffer" not in self.optimizer.state.get(self.fp32_partitioned_groups_flat[i], {}): + self._create_momentum_buffer(self.fp16_partitioned_groups_flat_numel[i], i, + self.fp32_partitioned_groups_flat[i].ds_id) + state_buffer = self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"] + for param, dest_offset, _ in group_items: + momentum_buffer.append(state_buffer.narrow(0, dest_offset, param.partition_numel()).clone()) + elif self.save_muon_momentum_buffer_in_memory: + state_buffer = self.muon_momentum_buffer_partitioned_groups_flat[i] + for param, dest_offset, _ in group_items: + momentum_buffer.append(state_buffer.narrow(0, dest_offset, param.partition_numel()).clone()) + else: + # Non-swappable optimizer (GPU/CPU): momentum buffer lives in optimizer state + if "momentum_buffer" not in self.optimizer.state.get(self.fp32_partitioned_groups_flat[i], {}): + self._create_momentum_buffer(self.fp16_partitioned_groups_flat_numel[i], i, + self.fp32_partitioned_groups_flat[i].ds_id) + state_buffer = self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"] + for param, dest_offset, _ in group_items: + momentum_buffer.append(state_buffer.narrow(0, dest_offset, param.partition_numel()).clone()) + + gathered_params_momentums = self._partitioned_buffers_all_gather(params, momentum_buffer, + communication_data_type) + + world_sz = dist.get_world_size(self.dp_process_group) + rank = dist.get_rank(self.dp_process_group) + grads_pad = [param.grad for param in params] + [torch.empty_like(params[-1].grad)] * ( + (world_sz - len(params) % world_sz) % world_sz) + gathered_momentums_pad 
= gathered_params_momentums + [torch.empty_like(gathered_params_momentums[-1])] * ( + (world_sz - len(gathered_params_momentums) % world_sz) % world_sz) + grad_handles = [] + momentum_handles = [] + for base_i in range(len(params))[::world_sz]: + if base_i + rank < len(params): + param = params[base_i + rank] + g = param.grad + m = gathered_momentums_pad[base_i + rank] + update = muon_update(g, m, beta=self.muon_beta) + g.data.copy_(update, non_blocking=False) + grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz], + grads_pad[base_i + rank], + async_op=True) + grad_handles.append(grad_handle) + momentum_handle = dist.all_gather(gathered_momentums_pad[base_i:base_i + world_sz], + gathered_momentums_pad[base_i + rank], + async_op=True) + momentum_handles.append(momentum_handle) + + for handle in momentum_handles: + handle.wait() + for idx, (param, dest_offset, _) in enumerate(group_items): + gathered_momentum = gathered_params_momentums[idx] + chunk_sz = math.ceil(param.grad.numel() / world_sz) + start_offset = rank * chunk_sz + end_offset = start_offset + chunk_sz + if end_offset > param.grad.numel(): + buffer_to_update = torch.zeros(chunk_sz, device=param.grad.device, dtype=param.grad.dtype) + buffer_to_update[:param.grad.numel() - + start_offset] = gathered_momentum.view(-1).data[start_offset:param.grad.numel()] + else: + buffer_to_update = gathered_momentum.view(-1).data[start_offset:end_offset] + if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory: + self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow( + 0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False) + elif self.save_muon_momentum_buffer_in_memory: + self.muon_momentum_buffer_partitioned_groups_flat[i].narrow( + 0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False) + # update the momentum buffer in the optimizer state + self.optimizer.state[self.fp32_partitioned_groups_flat[i]][ + "momentum_buffer"] = self.muon_momentum_buffer_partitioned_groups_flat[i] + else: + # Non-swappable optimizer (GPU/CPU): write directly to optimizer state + self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow( + 0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False) + if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory: + self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i]) + for handle in grad_handles: + handle.wait() + for param, _, params_size_offset in group_items: + buffer_to_reduce.narrow(0, params_size_offset, param.grad.numel()).data.copy_(param.grad.view(-1), + non_blocking=False) + @instrument_w_nvtx def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor, communication_data_type: torch.dtype) -> List[Tensor]: @@ -1453,6 +1615,7 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor, grad_partitions = [] grad_offset_in_buffer = 0 + self._apply_distributed_muon_update(communication_data_type, buffer_to_reduce) for param in self.ipg_buckets[communication_data_type].params: grad = param.grad chunk_sz = math.ceil(grad.numel() / world_sz) @@ -1627,6 +1790,56 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L gradient_tensors=offload_fp32_gradients[i]) return buffers + def _partitioned_buffers_all_gather(self, params: List[Parameter], buffers_to_allgather: List[Tensor], + 
communication_data_type: torch.dtype): + """ + Allgather the partitioned buffers of the parameters to the global buffer. + Args: + params: List[Parameter] + buffers_to_allgather: List[Tensor] + communication_data_type: torch.dtype + Returns: + List[Tensor] + """ + + assert len(params) == len(buffers_to_allgather), "params and buffers_to_allgather must have the same length" + assert all(param.partition_numel() == buffer.numel() + for param, + buffer in zip(params, buffers_to_allgather)), \ + "params and buffers_to_allgather must have the same numel" + coalesced_buffer = instrument_w_nvtx(torch.cat)(buffers_to_allgather) + buffer_numel = coalesced_buffer.numel() + reduce_buffer = torch.empty(self.partition_count * buffer_numel, + dtype=communication_data_type, + device=params[0].device) + rearrange_buffer = torch.empty(self.partition_count * buffer_numel, + dtype=communication_data_type, + device=params[0].device) + my_rank = dist.get_rank(group=self.dp_process_group) + partition = reduce_buffer.narrow(0, buffer_numel * my_rank, buffer_numel) + partition.data.copy_(coalesced_buffer.data, non_blocking=False) + dist.all_gather_into_tensor(reduce_buffer, partition, group=self.dp_process_group) + param_partition_offsets = [0] + rearranged_offset = 0 + for idx, param in enumerate(params): + param_partition_offsets.append(param_partition_offsets[idx] + param.partition_numel()) + for idx, param in enumerate(params): + num_elements = param.partition_numel() + for partition_idx in range(self.partition_count): + sliced = reduce_buffer.narrow(0, buffer_numel * partition_idx + param_partition_offsets[idx], + num_elements) + rearrange_buffer.narrow(0, rearranged_offset, num_elements).copy_(sliced.data, non_blocking=False) + rearranged_offset += num_elements + param_full_offsets = [0] + for idx, param in enumerate(params): + # the offset is the sum of the numel of all the partitions of the parameter including padding + param_full_offsets.append(param_full_offsets[idx] + + buffers_to_allgather[idx].numel() * self.partition_count) + output = [] + for idx, param in enumerate(params): + output.append(rearrange_buffer.narrow(0, param_full_offsets[idx], param.ds_numel).view(param.ds_shape)) + return output + def reduce_ready_partitions_and_remove_grads(self, param): #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) self.reduce_independent_p_g_buckets_and_remove_grads(param) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 92365d87f6fc..dd0ff9bffd9e 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -36,9 +36,11 @@ toc_label: "Contents" | Fields | Value | Example | | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------- | -| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, and **OneBitLamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | +| type | The optimizer name. 
DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, **OneBitLamb**, and **Muon** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | | params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | +Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe. + Example of **optimizer** with Adam ```json @@ -62,6 +64,24 @@ The Adam optimizer also supports the following two params keys/values in additio | torch\_adam | Use torch's implementation of adam instead of our fused adam implementation | false | | adam\_w\_mode | Apply L2 regularization (also known as AdamW) | true | +Example of **optimizer** with Muon +If not set, muon_lr will default to lr. +```json +"optimizer": { + "type": "Muon", + "params": { + "lr": 0.001, + "momentum": 0.9, + "weight_decay": 0.0, + "muon_lr": 0.001 + } + }, + "zero_optimization": { + "stage": 3, + "save_muon_momentum_buffer_in_memory": true + } +``` + Another example of **optimizer** with 1-bit Adam specific parameters is as follows. ```json diff --git a/tests/unit/ops/muon/test_muon.py b/tests/unit/ops/muon/test_muon.py index f12cbb358a82..02594941cef0 100644 --- a/tests/unit/ops/muon/test_muon.py +++ b/tests/unit/ops/muon/test_muon.py @@ -13,21 +13,27 @@ if torch.half not in get_accelerator().supported_dtypes(): pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) -# 'optimizer_type, zero_stage, lr, hidden_dim, nlayer' +# 'optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer, save_muon_momentum_buffer_in_memory' muon_configs = [] for optimizer_name in ['muon', 'adam']: - for stage in [1, 2]: + for stage in [1, 2, 3]: for lr in [0.01, 0.05]: for model_dim in [32, 128]: for nlayer in [5, 10]: - muon_configs.append([optimizer_name, stage, lr, model_dim, nlayer]) + for offload_optimizer in [True, False]: + for save_in_mem in ([True, False] if stage == 3 else [False]): + muon_configs.append( + [optimizer_name, stage, lr, model_dim, nlayer, offload_optimizer, save_in_mem]) -@pytest.mark.parametrize('optimizer_type, zero_stage, lr, hidden_dim, nlayer', muon_configs) +@pytest.mark.parametrize( + 'optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer, save_muon_momentum_buffer_in_memory', + muon_configs) class TestMuonConfigs(DistributedTest): - def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer): + def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer, + save_muon_momentum_buffer_in_memory): optimizer_params = {"lr": lr} batch_size = 8 config_dict = { @@ -42,8 +48,16 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer): }, "zero_optimization": { "stage": zero_stage, - } + "reduce_scatter": False, + "save_muon_momentum_buffer_in_memory": save_muon_momentum_buffer_in_memory, + }, } + if offload_optimizer: + 
config_dict["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + # Perform a few training steps to ensure the optimizer works correctly model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayer) diff --git a/tests/unit/runtime/test_autocast.py b/tests/unit/runtime/test_autocast.py index 682a98ae38bb..21ffc9bfbb4d 100644 --- a/tests/unit/runtime/test_autocast.py +++ b/tests/unit/runtime/test_autocast.py @@ -3,10 +3,14 @@ # DeepSpeed Team +import functools + import pytest import torch +import deepspeed.runtime.zero.linear as zero_linear from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3 from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import required_torch_version from unit.common import DistributedTest @@ -56,3 +60,30 @@ def test_autocast_linear(self, tmpdir, half_input, half_weight): with torch.amp.autocast(device_type=get_accelerator().device_name()): output = ds_linear(input) assert output.dtype == torch.half or output.dtype == torch.bfloat16 + + +def test_get_autocast_decorators_use_torch_amp_on_torch_2_4_or_newer(): + if not required_torch_version(min_version=2.4): + pytest.skip('torch.amp.custom_fwd/custom_bwd are only available on torch >= 2.4') + + device_type = get_accelerator().device_name() + + assert isinstance(zero_linear.autocast_custom_fwd, functools.partial) + assert isinstance(zero_linear.autocast_custom_bwd, functools.partial) + assert zero_linear.autocast_custom_fwd.func is torch.amp.custom_fwd + assert zero_linear.autocast_custom_bwd.func is torch.amp.custom_bwd + assert zero_linear.autocast_custom_fwd.keywords == {'device_type': device_type} + assert zero_linear.autocast_custom_bwd.keywords == {'device_type': device_type} + + +def test_get_autocast_decorators_use_legacy_amp_or_noop_before_torch_2_4(): + if required_torch_version(min_version=2.4): + pytest.skip('legacy AMP fallback only applies on torch < 2.4') + + device_type = get_accelerator().device_name() + legacy_amp = getattr(getattr(torch, device_type, None), 'amp', None) + expected_custom_fwd = getattr(legacy_amp, 'custom_fwd', zero_linear.noop_decorator) + expected_custom_bwd = getattr(legacy_amp, 'custom_bwd', zero_linear.noop_decorator) + + assert zero_linear.autocast_custom_fwd is expected_custom_fwd + assert zero_linear.autocast_custom_bwd is expected_custom_bwd