[bug][train] Fix max_seq_len calculation #1303
tamoghnokandar wants to merge 6 commits into NovaSky-AI:main from
Conversation
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
if cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm":
    if cfg.trainer.algorithm.max_seq_len is None:
        raise ValueError(
            "`trainer.algorithm.max_seq_len` must be set explicitly when "
            "`trainer.algorithm.loss_reduction='seq_mean_token_sum_norm'`. "
            "Choose the total sequence-length normalization constant for your setup; "
            "this often matches the model context window / vLLM `max_model_len` when appropriate."
        )
🔴 Breaking change: Dr. GRPO example script fails because auto-calculated max_seq_len fallback was removed
The PR removes the max_seq_len auto-calculation from SkyRLTrainConfig.__post_init__ (skyrl/train/config/config.py:713-722 on LEFT) and adds a hard assertion requiring it to be set explicitly when loss_reduction='seq_mean_token_sum_norm'. However, the official Dr. GRPO example script at examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh:15,23 uses LOSS_REDUCTION="seq_mean_token_sum_norm" but never passes trainer.algorithm.max_seq_len. This script previously worked because __post_init__ auto-computed max_seq_len = max_input_length + max_generate_length. Now it will crash with an AssertionError at validation time.
Same issue in skyrl-agent example
skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh:67 also sets trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" without setting max_seq_len, so it will also fail.
Prompt for agents
Two example scripts need to be updated to explicitly pass trainer.algorithm.max_seq_len now that the auto-calculation fallback has been removed:
1. examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh: Add a line like trainer.algorithm.max_seq_len=1536 (512 + 1024, matching max_prompt_length + max_generate_length from the script) to the uv run command.
2. skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh: Add a line like trainer.algorithm.max_seq_len=40768 (8000 + 32768, matching max_prompt_length + max_generate_length from the script) to the uv run command.
Both scripts use loss_reduction=seq_mean_token_sum_norm and will now fail the new assertion at skyrl/train/utils/utils.py:279-285 without this fix.
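As a quick sanity check on the two values suggested above, a throwaway helper (`expected_max_seq_len` is illustrative only, not a SkyRL API):

```python
def expected_max_seq_len(max_prompt_length: int, max_generate_length: int) -> int:
    # Sum of the prompt and generation budgets, matching what the removed
    # __post_init__ fallback would previously have auto-computed.
    return max_prompt_length + max_generate_length

# run_drgrpo_gsm8k.sh: 512 prompt + 1024 generate
assert expected_max_seq_len(512, 1024) == 1536
# run_skyrl_swe.sh: 8000 prompt + 32768 generate
assert expected_max_seq_len(8000, 32768) == 40768
```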
@tamoghnokandar this is important - can you grep for all usages of seq_mean_token_sum_norm in our example scripts and ensure that max_seq_len is explicitly passed in now (calculate it based on the generation and input lengths in the script).
SumanthRH left a comment
@tamoghnokandar can you merge the latest changes from main? We've had some important updates, especially bf243b8
Done!
can you also add a note in the loss_type docstring that max_seq_len is required for seq_mean_token_sum_norm now?
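A hedged sketch of what such a docstring note could look like (the class and field names here are stand-ins, not SkyRL's actual config):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class AlgorithmConfig:  # illustrative stand-in for the real config class
    loss_reduction: str = "token_mean"
    """Loss reduction strategy.

    NOTE: `max_seq_len` must be set explicitly when using
    `loss_reduction == "seq_mean_token_sum_norm"`; it is ignored otherwise.
    """
    max_seq_len: Optional[int] = None
```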
if self.trainer.algorithm.max_seq_len is None:
    # NOTE (erictang000): this is the max sequence length including the prompt, since max response length
    # per batch can be variable based on the prompt length. This is used to normalize the loss for
🔴 Auto-calculation of max_seq_len in __post_init__ not removed, making the new validation dead code
The PR's stated intent is to require users to explicitly set max_seq_len when using seq_mean_token_sum_norm loss reduction. The validation in skyrl/train/utils/utils.py:286-293 checks if cfg.trainer.algorithm.max_seq_len is None and raises ValueError. However, SkyRLTrainConfig.__post_init__ at skyrl/train/config/config.py:719-728 still auto-calculates max_seq_len whenever it's None:
if self.trainer.algorithm.max_seq_len is None:
    self.trainer.algorithm.max_seq_len = (
        self.generator.max_input_length + self.generator.sampling_params.max_generate_length
    )

Since __post_init__ runs at config construction time (before validate_cfg is called), max_seq_len will never be None when the validation runs. This means: (1) the validate_cfg check is dead code in normal usage, (2) users who don't set max_seq_len will silently get auto-calculated values instead of the intended error, and (3) the test test_max_seq_len_defaults_to_none_when_not_set will fail because __post_init__ populates the value. The auto-calculation block should be removed from __post_init__.
(Refers to lines 719-728)
Prompt for agents
The auto-calculation of max_seq_len in SkyRLTrainConfig.__post_init__ (config.py lines 719-728) must be removed to align with the PR's intent. The PR adds a validation check in validate_cfg (utils.py lines 286-293) that raises ValueError when max_seq_len is None and loss_reduction is seq_mean_token_sum_norm. But the __post_init__ always fills in max_seq_len before validate_cfg is ever called, so the validation never triggers.
To fix: remove the entire if-block at lines 719-728 in skyrl/train/config/config.py. This will allow max_seq_len to remain None when not explicitly set, and the validate_cfg check will correctly fire when a user uses seq_mean_token_sum_norm without setting max_seq_len.
Note: after removing the auto-calculation, code that passes max_seq_len to preprocess_data (trainer.py:631) and apply_loss_reduction_to_advantages_minibatch (trainer.py:1084) may receive None. The preprocess_data function already handles Optional[int] for max_seq_len (it just logs a warning). For apply_loss_reduction_to_advantages_minibatch, max_seq_len is only used in the seq_mean_token_sum_norm branch, which will be guarded by the new validation. But you should verify no other code paths rely on max_seq_len being non-None for non-seq_mean_token_sum_norm reductions.
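The intended behavior after removing the fallback can be sketched in miniature; the class and function names below are illustrative stand-ins, not SkyRL's actual `SkyRLTrainConfig` / `validate_cfg`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Algorithm:  # stand-in for trainer.algorithm
    loss_reduction: str = "token_mean"
    max_seq_len: Optional[int] = None  # stays None unless set explicitly

def validate(algo: Algorithm) -> None:
    # Mirrors the new check: only seq_mean_token_sum_norm requires
    # max_seq_len; token_mean / sequence_mean pass without it.
    if algo.loss_reduction == "seq_mean_token_sum_norm" and algo.max_seq_len is None:
        raise ValueError(
            "`trainer.algorithm.max_seq_len` must be set explicitly when "
            "`loss_reduction='seq_mean_token_sum_norm'`."
        )

validate(Algorithm())  # token_mean: fine without max_seq_len
validate(Algorithm(loss_reduction="seq_mean_token_sum_norm", max_seq_len=1536))
```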
@SumanthRH Should I remove this? For seq_mean_token_sum_norm loss we are setting it explicitly anyways. This is not required for other loss functions right?
Fixes #1154
Summary
This PR removes the implicit `max_seq_len` heuristic calculation and requires users to set it explicitly when using `trainer.algorithm.loss_reduction=seq_mean_token_sum_norm`.
Changes
- Removed the `trainer.algorithm.max_seq_len` default from `SkyRLTrainConfig.__post_init__`
- Require `trainer.algorithm.max_seq_len` to be explicitly set when `loss_reduction == "seq_mean_token_sum_norm"`
- `max_seq_len` must now be chosen based on the user's intended sequence-length normalization budget
- Tests cover:
  - `max_seq_len` remaining `None` by default
  - explicit `max_seq_len` values being preserved
  - `validate_cfg()` failing when `seq_mean_token_sum_norm` is used without `max_seq_len`
  - `validate_cfg()` continuing to allow `token_mean` and `sequence_mean` without `max_seq_len`
  - `validate_cfg()` passing when `seq_mean_token_sum_norm` is used with an explicit `max_seq_len`
Testing
- Added tests in `tests/train/test_config.py` for the new behavior
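The testing bullets above could be sketched as pytest-style cases along these lines; apart from `test_max_seq_len_defaults_to_none_when_not_set` (named in the review), the test names and the stand-in config/validator are illustrative:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Algo:  # stand-in for the real trainer.algorithm config
    loss_reduction: str = "token_mean"
    max_seq_len: Optional[int] = None

def validate_cfg(algo: Algo) -> None:
    # Stand-in for SkyRL's validate_cfg check added in this PR.
    if algo.loss_reduction == "seq_mean_token_sum_norm" and algo.max_seq_len is None:
        raise ValueError("max_seq_len must be set explicitly")

def test_max_seq_len_defaults_to_none_when_not_set():
    assert Algo().max_seq_len is None

def test_explicit_max_seq_len_preserved():
    assert Algo(max_seq_len=1536).max_seq_len == 1536

def test_seq_mean_token_sum_norm_requires_max_seq_len():
    try:
        validate_cfg(Algo(loss_reduction="seq_mean_token_sum_norm"))
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError")

def test_other_reductions_allowed_without_max_seq_len():
    validate_cfg(Algo(loss_reduction="token_mean"))
    validate_cfg(Algo(loss_reduction="sequence_mean"))
```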