Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions skyrl/backends/skyrl_train/distributed/ulysses/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@
get_ulysses_sequence_parallel_world_size,
)

# Module-level storage for position_ids (both sliced and all-gathered versions).
# Pre-computed in model_wrapper.py via set_ulysses_position_ids() before the model call,
# so that _ulysses_flash_attention_forward can use the cached version instead of running
# NCCL all_gather during gradient checkpointing backward recompute.
# Safe as a global because each Ray worker is a separate process with a single training thread.
_ulysses_position_ids_sliced: Optional[torch.Tensor] = None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

_ulysses_position_ids_gathered: Optional[torch.Tensor] = None


def set_ulysses_position_ids(position_ids: Optional[torch.Tensor]):
    """Cache position_ids (sliced and all-gathered) for _ulysses_flash_attention_forward.

    Must be called outside the checkpointed region (i.e., in model_wrapper.py before
    the model call) so the all-gather happens once up front instead of being replayed
    as an NCCL collective during gradient-checkpointing backward recompute.
    """
    global _ulysses_position_ids_sliced, _ulysses_position_ids_gathered

    _ulysses_position_ids_sliced = position_ids

    # No position_ids at all: clear the gathered cache and bail out early.
    if position_ids is None:
        _ulysses_position_ids_gathered = None
        return

    world_size = get_ulysses_sequence_parallel_world_size()
    if world_size <= 1:
        # Single SP rank: the local slice already covers the full sequence.
        _ulysses_position_ids_gathered = position_ids
        return

    # (bsz, seq_len/world_size) per rank -> (bsz, seq_len) after concatenation.
    shards = [torch.empty_like(position_ids) for _ in range(world_size)]
    torch.distributed.all_gather(shards, position_ids, group=get_ulysses_sequence_parallel_group())
    _ulysses_position_ids_gathered = torch.concat(shards, dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Expand Down Expand Up @@ -75,6 +102,10 @@ def _ulysses_flash_attention_forward(

########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
# For models that don't pass position_ids through decoder layers (e.g., GraniteMoeHybrid),
# fall back to the pre-gathered global set by model_wrapper.py.
if position_ids is None:
position_ids = _ulysses_position_ids_sliced
assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"
# NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k,
# we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.
Expand All @@ -91,13 +122,17 @@ def _ulysses_flash_attention_forward(
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)

# TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate
# this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.
# https://github.com/huggingface/transformers/pull/33932
# Use pre-gathered position_ids to avoid NCCL all_gather during gradient
# checkpointing backward recompute. The pre-gathered version is computed once
# in model_wrapper.py via set_ulysses_position_ids() before the model call.
# (bsz, seq_len/n) -> (bsz, seq_len)
position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
position_ids = torch.concat(position_ids_list, dim=-1)
if _ulysses_position_ids_gathered is not None:
position_ids = _ulysses_position_ids_gathered
else:
# Fallback: inline all_gather (only for non-checkpointed paths)
position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
position_ids = torch.concat(position_ids_list, dim=-1)

if attention_mask is not None:
# all gather attention mask
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def setup_envvars_for_vllm(kwargs, bundle_indices):
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
os.environ.pop("ROCR_VISIBLE_DEVICES", None)
os.environ.pop("HIP_VISIBLE_DEVICES", None)
# Ensure RAY_ADDRESS is set so that vLLM's EngineCore subprocess can
# connect back to the Ray cluster and query placement group state.
# Without this, the subprocess fails with KeyError: 'bundles' when
# accessing placement_group_table() because it can't reach the GCS.
if "RAY_ADDRESS" not in os.environ:
os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
elif noset_visible_devices:
# We need to set CUDA_VISIBLE_DEVICES to the ray assigned GPU
# when the distributed_executor_backend is not ray/mp and
Expand Down
10 changes: 10 additions & 0 deletions skyrl/backends/skyrl_train/workers/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from peft.tuners.lora import LoraLayer
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig

from skyrl.backends.skyrl_train.distributed.ulysses.monkey_patch import set_ulysses_position_ids
from skyrl.backends.skyrl_train.distributed.ulysses.utils import (
gather_outputs_and_unpad,
ulysses_pad_and_slice_inputs,
Expand Down Expand Up @@ -319,6 +320,12 @@ def forward(
sequences_rolled, None, None, self.sequence_parallel_size
)

# Store position_ids in module-level globals (safe: single-threaded Ray worker process)
# for architectures that don't propagate position_ids through decoder layers.
# Must be set here (outside the model) to survive gradient checkpointing backward recompute.
if self.sequence_parallel_size > 1:
set_ulysses_position_ids(position_ids_fwd)

# NOTE (sumanthrh): Once we have position_ids, we don't need attention mask with flash attention.
if self.use_sample_packing and self.attn_implementation == "flash_attention_2":
# NOTE (sumanthrh): Don't use attention mask. position_ids is enough.
Expand Down Expand Up @@ -479,6 +486,9 @@ def forward(
input_ids_fwd, position_ids_fwd, attention_mask_fwd, self.sequence_parallel_size
)

if self.sequence_parallel_size > 1:
set_ulysses_position_ids(position_ids_fwd)

if self.sequence_parallel_size > 1 and self.config._attn_implementation == "flash_attention_2":
outputs = getattr(self, self.base_model_prefix)(input_ids_fwd, position_ids=position_ids_fwd)
else:
Expand Down
76 changes: 45 additions & 31 deletions skyrl/backends/skyrl_train/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import socket
from collections import defaultdict
from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p
from ctypes import CDLL, c_int
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
Expand Down Expand Up @@ -176,54 +176,68 @@ def get_master_addr_port(self):
return self._master_addr, self._master_port

# TODO(tgriggs): For numa affinity, pass in the Worker._local_rank for the second arg here. Distinguish 'rank' and 'local_rank' differ here.
def _set_numa_affinity(self, rank):
def _set_numa_affinity(self, rank):  # noqa: ARG002 — rank kept for API compat, binding uses self._local_rank
    """Bind this worker process to the NUMA node closest to its GPU.

    Idempotent: guarded by the module-level _SET_AFFINITY flag so the binding
    (or a decision to skip) happens at most once per process.
    """

    def _visible_gpu_id(local_rank):
        # Map a local rank to the physical GPU id via CUDA_VISIBLE_DEVICES.
        visible = [int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(",")]
        return visible[local_rank % len(visible)]

    global _SET_AFFINITY
    if _SET_AFFINITY:
        return

    from ctypes.util import find_library

    try:
        libnuma = CDLL(find_library("numa"))
    except Exception as e:
        logger.error(f"Skipping NUMA affinity setup because libnuma is not installed: {e}")
        _SET_AFFINITY = True
        return

    # Check NUMA is actually functional before touching any other API.
    libnuma.numa_available.restype = c_int
    libnuma.numa_available.argtypes = []
    if libnuma.numa_available() < 0:
        logger.warning("NUMA not available on this system, skipping affinity")
        _SET_AFFINITY = True
        return

    # Use numa_max_node() NOT numa_num_configured_nodes().
    # On NVLink/GB200 systems, numa_num_configured_nodes() incorrectly counts
    # virtual NVLink NUMA IDs (e.g. 2,10,18,26) giving wrong total (e.g. 4).
    # numa_max_node() returns the real highest physical NUMA node index (e.g. 1).
    libnuma.numa_max_node.restype = c_int
    libnuma.numa_max_node.argtypes = []
    highest_node = libnuma.numa_max_node()  # e.g. 1 → real nodes are 0 and 1
    if highest_node < 0:
        logger.warning("numa_max_node() returned <0, skipping affinity")
        _SET_AFFINITY = True
        return
    physical_node_count = highest_node + 1  # e.g. 2

    # Use integer API — avoids bitmask pointer corruption that causes segfaults.
    libnuma.numa_run_on_node.restype = c_int
    libnuma.numa_run_on_node.argtypes = [c_int]
    libnuma.numa_set_preferred.restype = None
    libnuma.numa_set_preferred.argtypes = [c_int]

    gpu_id = _visible_gpu_id(self._local_rank)
    gpu_count = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(","))
    gpus_per_node = max(1, gpu_count // physical_node_count)
    # Clamp to [0, highest_node] — guaranteed safe.
    node_id = min(highest_node, gpu_id // gpus_per_node)

    logger.info(
        f"NUMA affinity: local_rank={self._local_rank}, real_gpu={gpu_id}, "
        f"real_numa_nodes={physical_node_count}, max_node={highest_node}, target={node_id}"
    )

    rc = libnuma.numa_run_on_node(node_id)
    if rc != 0:
        logger.warning(f"numa_run_on_node({node_id}) returned {rc}, may not have bound")
    libnuma.numa_set_preferred(node_id)
    _SET_AFFINITY = True


Expand Down
14 changes: 14 additions & 0 deletions skyrl/backends/skyrl_train/workers/worker_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
The trainer interacts with the worker dispatch if all models are always on GPU.
"""

import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import ray
from ray import ObjectRef
from loguru import logger

from skyrl.backends.skyrl_train.distributed.dispatch import (
MeshDispatch,
Expand Down Expand Up @@ -70,6 +72,10 @@ def __init__(
# GPU state tracking (only matters when colocated)
self._gpu_state: Dict[str, GPUState] = {name: GPUState() for name in self._actor_groups.keys()}

def get_dp_size(self, model: str) -> int:
    """Return the data-parallel world size for the given model's actor group."""
    first_actor = self._actor_groups[model].actor_infos[0]
    return first_actor.rank.dp_size

def get_lcm_dp_size(self) -> int:
"""Get LCM of all models' dp_size."""
import math
Expand Down Expand Up @@ -266,17 +272,25 @@ def forward_backward_from_staged(
end_idx=end_idx,
**kwargs,
)
t0 = time.time()
statuses = ray.get(refs)
logger.info(f"[dispatch] ray.get(forward_backward) done in {time.time() - t0:.1f}s")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's clean up debug logs


t0 = time.time()
self._save_memory_snapshot(model, "forward_backward")
logger.info(f"[dispatch] _save_memory_snapshot(forward_backward) done in {time.time() - t0:.1f}s")
return statuses[0]
Comment on lines +275 to 282
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the new per-micro-batch dispatch, this path now calls _save_memory_snapshot() for every micro-batch. Even when workers have record_memory disabled (no-op), this still incurs extra Ray RPC + synchronization overhead. Consider gating _save_memory_snapshot calls on the relevant config (e.g., trainer.policy.record_memory) or a dispatch-level flag so the default path avoids per-micro-batch RPCs.

Copilot uses AI. Check for mistakes.

def optim_step(self, model: str) -> Optional[float]:
    """Run optimizer step. Model should already be on GPU from forward_backward.

    Returns:
        The gradient norm reported by the first worker, or None if workers
        do not report one.
    """
    # NOTE: timing logs are debug-level — this now runs once per mini-batch,
    # and per-call info logs were flagged in review as flooding the logs.
    t0 = time.time()
    refs = self._actor_groups[model].async_run_ray_method("pass_through", "optim_step")
    grad_norms = ray.get(refs)
    logger.debug(f"[dispatch] ray.get(optim_step) done in {time.time() - t0:.1f}s")

    t0 = time.time()
    self._save_memory_snapshot(model, "optim_step")
    logger.debug(f"[dispatch] _save_memory_snapshot(optim_step) done in {time.time() - t0:.1f}s")
    return grad_norms[0]

def set_lr(self, model: str, learning_rate: float) -> None:
Expand Down
50 changes: 42 additions & 8 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import os
import shutil
import time
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path
Expand Down Expand Up @@ -1061,8 +1062,9 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s
"""
Execute training step for FSDP strategy using forward_backward + optim_step.

The trainer loops over epochs and mini-batches. Workers handle micro-batching
internally for gradient accumulation (memory efficiency).
Dispatches individual micro-batches to workers for per-micro-batch progress
visibility. Gradients accumulate on workers across micro-batches; optim_step
is called once per mini-batch to scale and apply.

Uses staged data approach: the full batch is put in Ray object store once,
and workers fetch + slice locally to avoid repeated serialization.
Expand All @@ -1081,6 +1083,12 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s
else:
mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples

# Micro-batch size per dispatch = micro_bs_per_gpu * dp_size
# so each worker gets exactly micro_train_batch_size_per_gpu samples
micro_bs_per_gpu = self.cfg.trainer.micro_train_batch_size_per_gpu
dp_size = self.dispatch.get_dp_size(model)
micro_dispatch_size = micro_bs_per_gpu * dp_size

all_metrics: Dict[str, List[float]] = defaultdict(list)

# Stage full batch in object store ONCE to avoid repeated serialization
Expand All @@ -1090,16 +1098,42 @@ def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[s
for _epoch in range(self.cfg.trainer.update_epochs_per_batch):
num_mini_batches = len(data) // mini_batch_size
for local_step in range(num_mini_batches):
start_idx = local_step * mini_batch_size
end_idx = (local_step + 1) * mini_batch_size
mb_start_idx = local_step * mini_batch_size
mb_end_idx = (local_step + 1) * mini_batch_size

# Dispatch individual micro-batches for progress visibility
num_micro_batches = math.ceil((mb_end_idx - mb_start_idx) / micro_dispatch_size)
t0 = time.time()
logger.info(
f"[{model}] mini-batch {local_step + 1}/{num_mini_batches}: "
f"dispatching {num_micro_batches} micro-batches "
f"(micro_bs={micro_bs_per_gpu}, dp={dp_size})"
)

for ub_idx in range(num_micro_batches):
ub_start = mb_start_idx + ub_idx * micro_dispatch_size
ub_end = min(ub_start + micro_dispatch_size, mb_end_idx)

Comment on lines +1104 to +1116
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop uses math.ceil() + min() for micro-batch slicing. Given MeshDispatch.dispatch_from_staged() requires (end_idx-start_idx) to be divisible by dp_size, it may be safer to enforce the stronger invariant here as well (e.g., assert (mb_end_idx-mb_start_idx) % micro_dispatch_size == 0 and use integer division) so misconfigured batch sizes fail fast with a clearer error than a later dispatch assertion.

Copilot uses AI. Check for mistakes.
ub_t0 = time.time()
status = self.dispatch.forward_backward_from_staged(model, data_ref, ub_start, ub_end)
ub_elapsed = time.time() - ub_t0

elapsed_total = time.time() - t0
avg_per_ub = elapsed_total / (ub_idx + 1)
remaining = avg_per_ub * (num_micro_batches - ub_idx - 1)
logger.info(
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per-micro-batch logger.info() inside the inner loop can generate very high log volume (and driver I/O overhead) for large runs. Consider making these debug-level, rate-limiting (e.g., every N micro-batches), or gating behind a config flag so production training isn’t slowed by logging.

Suggested change
logger.info(
logger.debug(

Copilot uses AI. Check for mistakes.
f"[{model}] micro-batch {ub_idx + 1}/{num_micro_batches} "
f"{ub_elapsed:.1f}s | elapsed {elapsed_total:.1f}s | ~{remaining:.0f}s left"
)

Comment on lines +1107 to 1128
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There have been changes on main since this PR was created, most importantly #1376 to avoid serialization overhead. I believe these changes are not needed.

Also, how would this work for Megatron where we don't write the micro batch loop explicitly ? We let Megatron handle the micro batching loop for PP

# Workers fetch from object store and slice locally
status = self.dispatch.forward_backward_from_staged(model, data_ref, start_idx, end_idx)
for k, v in status.items():
all_metrics[k].append(v)
for k, v in status.items():
all_metrics[k].append(v)

# Optimizer step after each mini batch
logger.info(f"[{model}] starting optim_step...")
optim_t0 = time.time()
grad_norm = self.dispatch.optim_step(model)
logger.info(f"[{model}] optim_step completed in {time.time() - optim_t0:.1f}s, grad_norm={grad_norm}")
if grad_norm is not None:
all_metrics["grad_norm"].append(grad_norm)

Expand Down
Loading