-
Notifications
You must be signed in to change notification settings - Fork 301
Ulysses position_ids pre-gather, NUMA rewrite, and operational improvements #1371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |||||||||
| import os | ||||||||||
| import socket | ||||||||||
| from collections import defaultdict | ||||||||||
| from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p | ||||||||||
| from ctypes import CDLL, c_int | ||||||||||
| from datetime import timedelta | ||||||||||
| from pathlib import Path | ||||||||||
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union | ||||||||||
|
|
@@ -176,54 +176,68 @@ def get_master_addr_port(self): | |||||||||
| return self._master_addr, self._master_port | ||||||||||
|
|
||||||||||
| # TODO(tgriggs): For numa affinity, pass in the Worker._local_rank for the second arg here. Note that 'rank' and 'local_rank' differ here. | ||||||||||
| def _set_numa_affinity(self, rank): | ||||||||||
| def _set_numa_affinity(self, rank): # noqa: ARG002 — rank kept for API compat, binding uses self._local_rank | ||||||||||
| def local_rank_to_real_gpu_id(local_rank): | ||||||||||
| cuda_visible_devices = [ | ||||||||||
| int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(",") | ||||||||||
| ] | ||||||||||
| return cuda_visible_devices[local_rank] | ||||||||||
|
|
||||||||||
| rank = local_rank_to_real_gpu_id(rank) | ||||||||||
| return cuda_visible_devices[local_rank % len(cuda_visible_devices)] | ||||||||||
|
|
||||||||||
|
||||||||||
| # Mark 'rank' as intentionally accepted for API compatibility, even if not used in logic yet. | |
| _ = rank |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you tested this on different GPUs?
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,11 +8,13 @@ | |
| The trainer interacts with the worker dispatch if all models are always on GPU. | ||
| """ | ||
|
|
||
| import time | ||
| from dataclasses import dataclass | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
| import ray | ||
| from ray import ObjectRef | ||
| from loguru import logger | ||
|
|
||
| from skyrl.backends.skyrl_train.distributed.dispatch import ( | ||
| MeshDispatch, | ||
|
|
@@ -70,6 +72,10 @@ def __init__( | |
| # GPU state tracking (only matters when colocated) | ||
| self._gpu_state: Dict[str, GPUState] = {name: GPUState() for name in self._actor_groups.keys()} | ||
|
|
||
| def get_dp_size(self, model: str) -> int: | ||
| """Get dp_size for a specific model.""" | ||
| return self._actor_groups[model].actor_infos[0].rank.dp_size | ||
|
|
||
| def get_lcm_dp_size(self) -> int: | ||
| """Get LCM of all models' dp_size.""" | ||
| import math | ||
|
|
@@ -266,17 +272,25 @@ def forward_backward_from_staged( | |
| end_idx=end_idx, | ||
| **kwargs, | ||
| ) | ||
| t0 = time.time() | ||
| statuses = ray.get(refs) | ||
| logger.info(f"[dispatch] ray.get(forward_backward) done in {time.time() - t0:.1f}s") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's clean up debug logs |
||
|
|
||
| t0 = time.time() | ||
| self._save_memory_snapshot(model, "forward_backward") | ||
| logger.info(f"[dispatch] _save_memory_snapshot(forward_backward) done in {time.time() - t0:.1f}s") | ||
| return statuses[0] | ||
|
Comment on lines
+275
to
282
|
||
|
|
||
| def optim_step(self, model: str) -> Optional[float]: | ||
| """Run optimizer step. Model should already be on GPU from forward_backward.""" | ||
| t0 = time.time() | ||
| refs = self._actor_groups[model].async_run_ray_method("pass_through", "optim_step") | ||
| grad_norms = ray.get(refs) | ||
| logger.info(f"[dispatch] ray.get(optim_step) done in {time.time() - t0:.1f}s") | ||
|
|
||
| t0 = time.time() | ||
| self._save_memory_snapshot(model, "optim_step") | ||
| logger.info(f"[dispatch] _save_memory_snapshot(optim_step) done in {time.time() - t0:.1f}s") | ||
|
Comment on lines
276
to
+293
|
||
| return grad_norms[0] | ||
|
|
||
| def set_lr(self, model: str, learning_rate: float) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |||||
| import math | ||||||
| import os | ||||||
| import shutil | ||||||
| import time | ||||||
| from collections import defaultdict | ||||||
| from dataclasses import asdict | ||||||
| from pathlib import Path | ||||||
|
|
@@ -1061,8 +1062,9 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s | |||||
| """ | ||||||
| Execute training step for FSDP strategy using forward_backward + optim_step. | ||||||
|
|
||||||
| The trainer loops over epochs and mini-batches. Workers handle micro-batching | ||||||
| internally for gradient accumulation (memory efficiency). | ||||||
| Dispatches individual micro-batches to workers for per-micro-batch progress | ||||||
| visibility. Gradients accumulate on workers across micro-batches; optim_step | ||||||
| is called once per mini-batch to scale and apply. | ||||||
|
|
||||||
| Uses staged data approach: the full batch is put in Ray object store once, | ||||||
| and workers fetch + slice locally to avoid repeated serialization. | ||||||
|
|
@@ -1081,6 +1083,12 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s | |||||
| else: | ||||||
| mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples | ||||||
|
|
||||||
| # Micro-batch size per dispatch = micro_bs_per_gpu * dp_size | ||||||
| # so each worker gets exactly micro_train_batch_size_per_gpu samples | ||||||
| micro_bs_per_gpu = self.cfg.trainer.micro_train_batch_size_per_gpu | ||||||
| dp_size = self.dispatch.get_dp_size(model) | ||||||
| micro_dispatch_size = micro_bs_per_gpu * dp_size | ||||||
|
|
||||||
| all_metrics: Dict[str, List[float]] = defaultdict(list) | ||||||
|
|
||||||
| # Stage full batch in object store ONCE to avoid repeated serialization | ||||||
|
|
@@ -1090,16 +1098,42 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s | |||||
| for _epoch in range(self.cfg.trainer.update_epochs_per_batch): | ||||||
| num_mini_batches = len(data) // mini_batch_size | ||||||
| for local_step in range(num_mini_batches): | ||||||
| start_idx = local_step * mini_batch_size | ||||||
| end_idx = (local_step + 1) * mini_batch_size | ||||||
| mb_start_idx = local_step * mini_batch_size | ||||||
| mb_end_idx = (local_step + 1) * mini_batch_size | ||||||
|
|
||||||
| # Dispatch individual micro-batches for progress visibility | ||||||
| num_micro_batches = math.ceil((mb_end_idx - mb_start_idx) / micro_dispatch_size) | ||||||
| t0 = time.time() | ||||||
| logger.info( | ||||||
| f"[{model}] mini-batch {local_step + 1}/{num_mini_batches}: " | ||||||
| f"dispatching {num_micro_batches} micro-batches " | ||||||
| f"(micro_bs={micro_bs_per_gpu}, dp={dp_size})" | ||||||
| ) | ||||||
|
|
||||||
| for ub_idx in range(num_micro_batches): | ||||||
| ub_start = mb_start_idx + ub_idx * micro_dispatch_size | ||||||
| ub_end = min(ub_start + micro_dispatch_size, mb_end_idx) | ||||||
|
|
||||||
|
Comment on lines
+1104
to
+1116
|
||||||
| ub_t0 = time.time() | ||||||
| status = self.dispatch.forward_backward_from_staged(model, data_ref, ub_start, ub_end) | ||||||
| ub_elapsed = time.time() - ub_t0 | ||||||
|
|
||||||
| elapsed_total = time.time() - t0 | ||||||
| avg_per_ub = elapsed_total / (ub_idx + 1) | ||||||
| remaining = avg_per_ub * (num_micro_batches - ub_idx - 1) | ||||||
| logger.info( | ||||||
|
||||||
| logger.info( | |
| logger.debug( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There have been changes on main since this PR was created, most importantly #1376 to avoid serialization overhead. I believe these changes are not needed.
Also, how would this work for Megatron, where we don't write the micro-batch loop explicitly? We let Megatron handle the micro-batching loop for PP.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!