Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
589c150
Move loss reduction normalization to trainer-level advantage scaling,…
justinvyu Mar 9, 2026
333f31a
Add token_mean_baseline loss reduction for mean-of-microbatch-means c…
justinvyu Mar 9, 2026
aaaba4c
fix assertion
justinvyu Mar 9, 2026
a121360
Update tests for sum-based reduce_loss and dp_size scaling changes
justinvyu Mar 10, 2026
15de89a
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu Mar 17, 2026
e3842c3
lint
justinvyu Mar 17, 2026
13bfe80
fix tests
justinvyu Mar 17, 2026
e76bece
Refactor advantage normalization: fix z-score propagation, skip for c…
justinvyu Mar 20, 2026
0192e8e
token_mean_baseline -> token_mean_legacy
justinvyu Mar 20, 2026
4ee0b31
Extract apply_loss_reduction_to_advantages_minibatch to ppo_utils and…
justinvyu Mar 20, 2026
c8f06cc
Fix metric reporting: remove dp_size scaling, separate micro-batch vs…
justinvyu Mar 25, 2026
2c13315
Fix critic metric reporting: explicit sum_loss_metrics flag for reduc…
justinvyu Mar 27, 2026
14ba02e
Remove reduce_metrics_across_minibatches, reuse reduce_metrics
justinvyu Mar 27, 2026
0cfc95b
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu Mar 27, 2026
717c3a7
add some comments about sum metrics
justinvyu Mar 27, 2026
661f5d8
add clarifying comments and rename loss_scale
justinvyu Mar 27, 2026
5cc95a1
no_grad for safety and make private
justinvyu Mar 27, 2026
ce8f6aa
remove outdated comments about loss reduction type in sapo tests
justinvyu Mar 27, 2026
1a60bb5
fix test
justinvyu Mar 27, 2026
c5feb83
fix test
justinvyu Mar 27, 2026
a599a4e
fix kl, entropy loss terms
justinvyu Mar 30, 2026
971be5f
revert sft
justinvyu Mar 30, 2026
308be63
finish reverting sft + normalize by num microbatches for the kl/entro…
justinvyu Mar 30, 2026
ad54440
revert sft case for megatron
justinvyu Mar 30, 2026
3829da5
fix tests
justinvyu Mar 30, 2026
a19ebca
add some comments about the kl/entropy terms
justinvyu Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/train/async/async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from skyrl.train.trainer import RayPPOTrainer
from tqdm import tqdm
from skyrl.train.utils import Timer
from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict
from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.train.generators.base import GeneratorOutput
from skyrl.train.utils.trainer_utils import ResumeMode
Expand Down Expand Up @@ -146,9 +145,6 @@ async def _run_training(self, generation_buffer):
training_input.pop(key)
training_input.metadata.pop("uids")

if self.cfg.trainer.algorithm.advantage_batch_normalize:
training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
Expand Down
4 changes: 0 additions & 4 deletions skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from skyrl.train.utils import Timer
from skyrl.backends.skyrl_train.utils.ppo_utils import (
get_kl_controller,
normalize_advantages_dict,
)
from skyrl.train.utils.trainer_utils import (
validate_generator_output,
Expand Down Expand Up @@ -382,9 +381,6 @@ async def train(self):
training_input.pop(key)
training_input.metadata.pop("uids")

if self.cfg.trainer.algorithm.advantage_batch_normalize:
training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
Expand Down
17 changes: 9 additions & 8 deletions skyrl/backends/skyrl_train/distributed/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def get_rank(self) -> int:
"""Get current process rank"""
return dist.get_rank()

def all_reduce(self, data: DataT, op="mean") -> DataT:
"""Perform all_reduce across all processes"""
def all_reduce(self, data: DataT, op="mean", group=None) -> DataT:
"""Perform all_reduce across all processes (or within a process group)."""
assert op in ("mean", "max", "sum", "min")
if isinstance(data, dict):
return {k: self.all_reduce(v, op) for k, v in data.items()}
return {k: self.all_reduce(v, op, group=group) for k, v in data.items()}
else:
is_tensor = True
if not isinstance(data, torch.Tensor):
Expand All @@ -82,14 +82,15 @@ def all_reduce(self, data: DataT, op="mean") -> DataT:
if is_cpu_tensor:
data = data.to(torch.cuda.current_device())
if op == "mean":
data /= self.world_size
dist.all_reduce(data, op=dist.ReduceOp.SUM)
group_size = dist.get_world_size(group) if group is not None else self.world_size
data /= group_size
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
elif op == "max":
dist.all_reduce(data, op=dist.ReduceOp.MAX)
dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group)
elif op == "min":
dist.all_reduce(data, op=dist.ReduceOp.MIN)
dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group)
elif op == "sum":
dist.all_reduce(data, op=dist.ReduceOp.SUM)
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
if is_cpu_tensor:
data = data.cpu()
return data.item() if not is_tensor else data
Expand Down
133 changes: 57 additions & 76 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
from collections import defaultdict
from enum import StrEnum
from functools import wraps
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import ray
import torch
from jaxtyping import Float
from loguru import logger

from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.backends.skyrl_train.utils.off_policy_correction_utils import (
apply_off_policy_correction,
)
Expand Down Expand Up @@ -124,27 +123,6 @@ def compute_approx_kl(
return kld


@torch.no_grad()
def normalize_advantages_dict(data: "TrainingInputBatch") -> "TrainingInputBatch":
    """Z-score normalize the advantages in the data batch, in place.

    Both the mean and the standard deviation are computed only over valid
    response tokens (where ``response_mask`` is 1). Previously the mean was
    computed over *all* positions (including padding), which diluted it and
    made it inconsistent with the masked variance; this now uses a masked
    mean so both moments agree.

    Expects:
        - ``data["advantages"]``: Float[torch.Tensor, "batch_size seqlen"]
        - ``data["response_mask"]``: Float[torch.Tensor, "batch_size seqlen"]

    Returns:
        The same batch object with ``data["advantages"]`` replaced by
        ``(advantages - mean) / std``. Padded positions are shifted/scaled
        too; downstream consumers are expected to re-apply the mask.
    """
    advantages: Float[torch.Tensor, "batch_size seqlen"] = data["advantages"]
    response_masks: Float[torch.Tensor, "batch_size seqlen"] = data["response_mask"]
    # Total number of valid tokens; clamp to avoid division by zero on an
    # empty mask (degenerate but possible with fully-filtered batches).
    num_actions = response_masks.sum().clamp(min=1.0)
    # Masked mean over valid tokens only.
    mean = (advantages * response_masks).sum() / num_actions
    # Masked variance; rsqrt of the clamped variance gives 1/std safely.
    var = ((advantages - mean).pow(2) * response_masks).sum()
    rstd = (var / num_actions).clamp(min=1e-8).rsqrt()

    data["advantages"] = (advantages - mean) * rstd
    return data


def masked_var(values, mask, unbiased=True):
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
Expand Down Expand Up @@ -558,12 +536,6 @@ def ppo_policy_loss(
rollout_logprobs: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict[str, float]]:
assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'"
loss_reduction = config.loss_reduction
assert loss_reduction in [
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
surr1 = ratio * advantages
Expand All @@ -584,7 +556,7 @@ def ppo_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -656,8 +628,7 @@ def gate_function(x, tau):
)
loss_metrics.update(off_policy_metrics)

# for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)

return loss, loss_metrics

Expand Down Expand Up @@ -726,7 +697,7 @@ def gspo_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)

return loss, loss_metrics

Expand Down Expand Up @@ -763,7 +734,7 @@ def compute_policy_loss_cispo(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -791,13 +762,6 @@ def rollout_is_policy_loss(
"""
assert rollout_logprobs is not None, "rollout_logprobs are required for rollout_is"

loss_reduction = config.loss_reduction
assert loss_reduction in [
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = safe_exp_delta(log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype)

in_range = (ratio > 1 - config.eps_clip_low) & (ratio < 1 + config.eps_clip_high)
Expand All @@ -812,7 +776,7 @@ def rollout_is_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -874,12 +838,7 @@ def compute_policy_loss_clip_cov(
# Apply correction mask to losses
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr

pg_loss = reduce_loss(
loss=pg_losses,
loss_mask=loss_mask,
loss_reduction=config.loss_reduction,
max_seq_len=config.max_seq_len,
)
pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask)

return pg_loss, {"clip_ratio": clip_frac.item()}

Expand Down Expand Up @@ -933,12 +892,7 @@ def compute_policy_loss_kl_cov(
large_cov_idxs % advantages.shape[1],
]

pg_loss = reduce_loss(
loss=pg_losses,
loss_mask=loss_mask,
loss_reduction=config.loss_reduction,
max_seq_len=config.max_seq_len,
)
pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask)

# NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0
return pg_loss, {"clip_ratio": 0.0}
Expand Down Expand Up @@ -977,10 +931,7 @@ def cross_entropy_loss(
elementwise_loss = -log_probs

# Apply loss mask and sum (matching Tinker's SUM reduction semantics)
if loss_mask is not None:
loss = (elementwise_loss * loss_mask).sum()
else:
loss = elementwise_loss.sum()
loss = reduce_loss(elementwise_loss, loss_mask)

# No clipping in cross-entropy loss
return loss, {"clip_ratio": 0.0}
Expand Down Expand Up @@ -1039,30 +990,60 @@ def importance_sampling_loss(
def reduce_loss(
    loss: torch.Tensor,
    loss_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sum the per-token loss over valid tokens.

    Loss-reduction scaling (token_mean, sequence_mean, seq_mean_token_sum_norm)
    is applied upstream at the advantage level (see
    ``apply_loss_reduction_to_advantages_minibatch``), so a plain masked sum
    here yields the desired reduction once gradients are accumulated across
    micro-batches and data-parallel workers.

    Args:
        loss: Per-token loss of shape (batch_size, seq_len).
        loss_mask: Optional 0/1 mask of the same shape; ``None`` means all
            tokens contribute.

    Returns:
        Scalar tensor: the (masked) sum of ``loss``.
    """
    return (loss * loss_mask).sum() if loss_mask is not None else loss.sum()


def apply_loss_reduction_to_advantages_minibatch(
    advantages: torch.Tensor,
    loss_mask: torch.Tensor,
    loss_reduction: str,
    micro_batch_size: int,
    max_seq_len: int,
) -> torch.Tensor:
    """Scale advantages so that summing the loss produces the desired reduction.

    The per-token policy loss is later reduced with a plain masked *sum*
    (see ``reduce_loss``); pre-dividing the advantages here makes that sum
    equal the chosen reduction over the whole minibatch.

    Args:
        advantages: Advantage tensor of shape (minibatch_size, seq_len).
        loss_mask: Mask of shape (minibatch_size, seq_len) indicating valid
            loss tokens.
        loss_reduction: One of "token_mean", "token_mean_legacy",
            "sequence_mean", "seq_mean_token_sum_norm".
        micro_batch_size: Number of sequences per micro-batch (used only by
            "token_mean_legacy").
        max_seq_len: Maximum sequence length (used only by
            "seq_mean_token_sum_norm").

    Returns:
        Scaled advantages tensor of the same shape.

    Raises:
        ValueError: If ``loss_reduction`` is not one of the supported modes.
    """
    batch_size = advantages.shape[0]
    normalized_advantages = torch.zeros_like(advantages)

    # Option 1: token mean — sum over *all* valid tokens in the minibatch,
    # divided by the total valid-token count.
    if loss_reduction == "token_mean":
        normalized_advantages = advantages / loss_mask.sum().clamp(min=1)

    # Option 1b: legacy token-mean that normalizes per-microbatch and then
    # averages across microbatches (mean of microbatch means).
    elif loss_reduction == "token_mean_legacy":
        # NOTE(review): assumes batch_size is an exact multiple of
        # micro_batch_size; trailing sequences would otherwise be dropped
        # (left zeroed) — confirm against the caller's batching logic.
        num_micro_batches = batch_size // micro_batch_size
        for i in range(num_micro_batches):
            start_idx = i * micro_batch_size
            end_idx = (i + 1) * micro_batch_size
            mb_advantages = advantages[start_idx:end_idx]
            mb_loss_mask = loss_mask[start_idx:end_idx]
            mb_advantages = mb_advantages / mb_loss_mask.sum().clamp(min=1)
            mb_advantages /= num_micro_batches
            normalized_advantages[start_idx:end_idx] = mb_advantages

    # Option 2: sequence mean — per-sequence token-mean, then batch-mean.
    elif loss_reduction == "sequence_mean":
        normalized_advantages = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1))

    # Option 3: Dr. GRPO style reduction — per-sequence token-sum normalized
    # by a constant (max_seq_len) to avoid length bias, then batch-mean.
    elif loss_reduction == "seq_mean_token_sum_norm":
        assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction"
        # NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to
        # cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length
        normalized_advantages = advantages / (batch_size * max_seq_len)

    else:
        raise ValueError(f"Invalid loss reduction type: {loss_reduction}")

    return normalized_advantages


# NOTE (erictang000): below ported from verl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def loss_func(logits, data):
loss_mask = data["loss_mask"]
rollout_action_logprobs = data["rollout_action_logprobs"]
action_mask = data.get("action_mask")
num_microbatches = data.get("num_microbatches")

dp_size = mpu.get_data_parallel_world_size()
tp_grp = mpu.get_tensor_model_parallel_group()
tp_rank = mpu.get_tensor_model_parallel_rank()

Expand Down Expand Up @@ -310,7 +312,7 @@ def loss_func(logits, data):
)

metrics = {
"loss": loss.detach().item(),
"loss": loss.item(),
"response_length": num_actions,
"loss_fn_outputs": loss_fn_outputs,
}
Expand Down Expand Up @@ -340,7 +342,19 @@ def loss_func(logits, data):
kl_loss = torch.tensor(0.0)
kl_loss_term = kl_loss * loss_config.kl_loss_coef

loss = policy_loss + kl_loss_term - entropy_loss_term
# Policy losses are pre-scaled to achieve the correct loss_reduction
# when summing across the entire minibatch (see `apply_loss_reduction_to_advantages_minibatch`).
# Megatron divides loss by num_microbatches
# (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248)
# and the data parallel all-reduce averages gradients across dp_size
# (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285)
# so we multiply by both factors to recover the correct sum reduction.
grad_sum_correction_factor = num_microbatches * dp_size

# NOTE: The KL and entropy loss terms are not pre-scaled,
# so we just average them across microbatches and DP workers.
loss = policy_loss * grad_sum_correction_factor + kl_loss_term - entropy_loss_term
unscaled_loss = loss / grad_sum_correction_factor
Comment on lines +352 to +357
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is similar to the FSDP case, except megatron already divides by num_microbatches and dp_size internally (so no need to divide by num_microbatches here).


# Build per-sequence loss_fn_outputs with logprobs.
batch_size = action_log_probs.shape[0]
Expand All @@ -363,7 +377,7 @@ def loss_func(logits, data):
)

metrics = {
"final_loss": loss.detach().item(),
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
Comment on lines +380 to +381
Copy link
Copy Markdown
Contributor Author

@justinvyu justinvyu Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metrics fix 1: remove dp_size multiplier in reported metrics, since there's no average that we need to correct for, since reduce_microbatch_metrics and all_reduce_metrics both do sums for *_loss metrics.

"policy_entropy": entropy.detach().item(),
"policy_kl": kl_loss.detach().item(),
Expand Down
20 changes: 14 additions & 6 deletions skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,9 @@ def forward_backward(
}
)

for m_batch in micro_buffer:
m_batch["num_microbatches"] = len(micro_buffer)

if not micro_buffer:
return {}

Expand All @@ -718,9 +721,6 @@ def forward_backward(
if self.empty_cuda_cache:
torch.cuda.empty_cache()

# Track number of micro-batches for metrics
self._micro_batches_accumulated += len(micro_buffer)

# Aggregate metrics across micro-batches
all_loss_fn_outputs = [] # Handle separately from scalar metrics
for metrics in metrics_list:
Expand All @@ -730,10 +730,18 @@ def forward_backward(
for k, v in metrics.items():
all_metrics[k].append(v)

# Reduce and all-reduce metrics
status = reduce_metrics(dict(all_metrics))
# TODO: SFT path still averages metrics across microbatches and workers.
# This needs to be unified with the RL path which sums.
resolved_loss_name = loss_fn or self.cfg.algorithm.policy_loss_type
sum_loss_metrics = resolved_loss_name != "cross_entropy"

# Reduce across microbatches and all-reduce metrics across DP ranks
# (metrics should be identical within DP groups, i.e., across TP/PP/SP ranks)
# NOTE: Sum loss metrics because scaling is already applied at the advantage level
status = reduce_metrics(all_metrics, sum_loss_metrics=sum_loss_metrics)
status["policy_lr"] = self.optimizer.param_groups[0]["lr"]
status = all_reduce_metrics(status, self.strategy)
group = mpu.get_data_parallel_group(with_context_parallel=True)
status = all_reduce_metrics(status, self.strategy, group=group, sum_loss_metrics=sum_loss_metrics)

# Add loss_fn_outputs back (not reduced, kept as list)
if all_loss_fn_outputs:
Expand Down
Loading
Loading