[BREAKING][skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy (#1296)
Changes from all commits: 589c150, 333f31a, aaaba4c, a121360, 15de89a, e3842c3, 13bfe80, e76bece, 0192e8e, 4ee0b31, c8f06cc, 2c13315, 14ba02e, 0cfc95b, 717c3a7, 661f5d8, 5cc95a1, ce8f6aa, 1a60bb5, c5feb83, a599a4e, 971be5f, 308be63, ad54440, 3829da5, a19ebca
```diff
@@ -246,7 +246,9 @@ def loss_func(logits, data):
     loss_mask = data["loss_mask"]
     rollout_action_logprobs = data["rollout_action_logprobs"]
     action_mask = data.get("action_mask")
+    num_microbatches = data.get("num_microbatches")
+    dp_size = mpu.get_data_parallel_world_size()
     tp_grp = mpu.get_tensor_model_parallel_group()
     tp_rank = mpu.get_tensor_model_parallel_rank()
```
```diff
@@ -310,7 +312,7 @@ def loss_func(logits, data):
     )

     metrics = {
-        "loss": loss.detach().item(),
+        "loss": loss.item(),
         "response_length": num_actions,
         "loss_fn_outputs": loss_fn_outputs,
     }
```
```diff
@@ -340,7 +342,19 @@ def loss_func(logits, data):
     kl_loss = torch.tensor(0.0)
     kl_loss_term = kl_loss * loss_config.kl_loss_coef

-    loss = policy_loss + kl_loss_term - entropy_loss_term
+    # Policy losses are pre-scaled to achieve the correct loss_reduction
+    # when summing across the entire minibatch (see `apply_loss_reduction_to_advantages_minibatch`).
+    # Megatron divides loss by num_microbatches
+    # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248)
+    # and the data parallel all-reduce averages gradients across dp_size
+    # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285)
+    # so we multiply by both factors to recover the correct sum reduction.
+    grad_sum_correction_factor = num_microbatches * dp_size
+
+    # NOTE: The KL and entropy loss terms are not pre-scaled,
+    # so we just average them across microbatches and DP workers.
+    loss = policy_loss * grad_sum_correction_factor + kl_loss_term - entropy_loss_term
+    unscaled_loss = loss / grad_sum_correction_factor
```
Comment on lines +352 to +357 (Contributor, Author): This is similar to the FSDP case, except Megatron already divides by num_microbatches and dp_size internally (so no need to divide by num_microbatches here).
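The interaction described above can be checked with a small numeric sketch. All values here are illustrative, not taken from skyrl-train: if the framework averages the loss over `num_microbatches` and the data-parallel all-reduce averages over `dp_size`, pre-multiplying each microbatch loss by `num_microbatches * dp_size` recovers a plain sum over the whole minibatch.

```python
# Hypothetical numbers, for illustration only.
num_microbatches = 4
dp_size = 2

# Per-(rank, microbatch) policy losses whose minibatch-wide sum we want.
losses = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
target_sum = sum(sum(row) for row in losses)  # 36.0

factor = num_microbatches * dp_size

# What the framework effectively computes: pre-scale each loss by `factor`,
# average over microbatches (Megatron's division), then average over ranks
# (the data-parallel all-reduce).
framework_result = sum(
    sum(l * factor for l in row) / num_microbatches for row in losses
) / dp_size

assert framework_result == target_sum  # the two averages are exactly undone
```

The two division factors cancel the multiplication exactly, which is why the diff applies the correction once, only to the pre-scaled policy loss.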
```diff
     # Build per-sequence loss_fn_outputs with logprobs.
     batch_size = action_log_probs.shape[0]
```
```diff
@@ -363,7 +377,7 @@ def loss_func(logits, data):
     )

     metrics = {
-        "final_loss": loss.detach().item(),
+        "final_loss": unscaled_loss.detach().item(),
     "policy_loss": policy_loss.detach().item(),
```
Comment on lines +380 to +381 (Contributor, Author): Metrics fix 1: remove
```diff
         "policy_entropy": entropy.detach().item(),
         "policy_kl": kl_loss.detach().item(),
```
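Taken together, the scaled loss used for `backward()` and the unscaled value logged under `final_loss` relate as in this sketch (the numbers are made up; variable names follow the diff):

```python
# Illustrative values only; variable names follow the diff above.
num_microbatches, dp_size = 4, 2
grad_sum_correction_factor = num_microbatches * dp_size  # 8

policy_loss = 0.5          # pre-scaled per-microbatch policy loss
kl_loss_term = 0.1         # not pre-scaled: simply averaged by the framework
entropy_loss_term = 0.02   # not pre-scaled either

# Scaled loss passed to backward(); Megatron later divides by
# num_microbatches and the DP all-reduce divides by dp_size.
loss = policy_loss * grad_sum_correction_factor + kl_loss_term - entropy_loss_term

# Divide the scaling back out so the logged "final_loss" metric stays on
# the same scale as before this change.
unscaled_loss = loss / grad_sum_correction_factor
```

Logging `unscaled_loss` instead of `loss` keeps the reported metric comparable across runs even though the value actually backpropagated is `grad_sum_correction_factor` times larger.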