Skip to content

[bug][train] Fix max_seq_len calculation#1303

Open
tamoghnokandar wants to merge 6 commits intoNovaSky-AI:mainfrom
tamoghnokandar:fix_max_seq_len
Open

[bug][train] Fix max_seq_len calculation#1303
tamoghnokandar wants to merge 6 commits intoNovaSky-AI:mainfrom
tamoghnokandar:fix_max_seq_len

Conversation

@tamoghnokandar
Copy link
Copy Markdown
Contributor

@tamoghnokandar tamoghnokandar commented Mar 10, 2026

Fixes #1154

Summary

This PR removes the implicit max_seq_len heuristic calculation and requires users to set it explicitly when using trainer.algorithm.loss_reduction=seq_mean_token_sum_norm.

Changes

  • Removed the automatic trainer.algorithm.max_seq_len default from SkyRLTrainConfig.__post_init__
  • Added config validation that requires trainer.algorithm.max_seq_len to be explicitly set when loss_reduction == "seq_mean_token_sum_norm"
  • Updated config comments and docs to reflect the new behavior and explain that max_seq_len must be chosen based on the user’s intended sequence-length normalization budget
  • Updated tests to cover:
    • max_seq_len remaining None by default
    • explicit max_seq_len values being preserved
    • validate_cfg() failing when seq_mean_token_sum_norm is used without max_seq_len
    • validate_cfg() continuing to allow token_mean and sequence_mean without max_seq_len
    • validate_cfg() passing when seq_mean_token_sum_norm is used with an explicit max_seq_len

Testing

  • Updated tests/train/test_config.py for the new behavior

Open with Devin

gemini-code-assist[bot]

This comment was marked as resolved.

tamoghnokandar and others added 2 commits March 9, 2026 18:04
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@devin-ai-integration devin-ai-integration bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Devin Review found 3 new potential issues.

View 5 additional findings in Devin Review.

Open in Devin Review

Comment on lines +279 to +285
if cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm":
if cfg.trainer.algorithm.max_seq_len is None:
raise ValueError(
"`trainer.algorithm.max_seq_len` must be set explicitly when "
"`trainer.algorithm.loss_reduction='seq_mean_token_sum_norm'`. "
"Choose the total sequence-length normalization constant for your setup; "
"this often matches the model context window / vLLM `max_model_len` when appropriate."
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Breaking change: Dr. GRPO example script fails because auto-calculated max_seq_len fallback was removed

The PR removes the max_seq_len auto-calculation from SkyRLTrainConfig.__post_init__ (skyrl/train/config/config.py:713-722 on LEFT) and adds a hard assertion requiring it to be set explicitly when loss_reduction='seq_mean_token_sum_norm'. However, the official Dr. GRPO example script at examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh:15,23 uses LOSS_REDUCTION="seq_mean_token_sum_norm" but never passes trainer.algorithm.max_seq_len. This script previously worked because __post_init__ auto-computed max_seq_len = max_input_length + max_generate_length. Now it will crash with an AssertionError at validation time.

Same issue in skyrl-agent example

skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh:67 also sets trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" without setting max_seq_len, so it will also fail.

Prompt for agents
Two example scripts need to be updated to explicitly pass trainer.algorithm.max_seq_len now that the auto-calculation fallback has been removed:

1. examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh: Add a line like trainer.algorithm.max_seq_len=1536 (512 + 1024, matching max_prompt_length + max_generate_length from the script) to the uv run command.

2. skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh: Add a line like trainer.algorithm.max_seq_len=40768 (8000 + 32768, matching max_prompt_length + max_generate_length from the script) to the uv run command.

Both scripts use loss_reduction=seq_mean_token_sum_norm and will now fail the new assertion at skyrl/train/utils/utils.py:279-285 without this fix.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tamoghnokandar this is important - can you grep for all usages of seq_mean_token_sum_norm in our example scripts and ensure that max_seq_len is explicitly passed in now (calculate it based on the generation and input lengths in the script).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOSS_REDUCTION="seq_mean_token_sum_norm"

LOSS_REDUCTION="seq_mean_token_sum_norm"

trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" \

Comment thread skyrl/train/utils/utils.py
Comment thread skyrl/train/config/config.py
@SumanthRH SumanthRH self-assigned this Mar 19, 2026
@SumanthRH SumanthRH self-requested a review March 19, 2026 19:20
Copy link
Copy Markdown
Member

@SumanthRH SumanthRH left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tamoghnokandar can you merge the latest changes from main? We've had some important updates, especially bf243b8

@tamoghnokandar
Copy link
Copy Markdown
Contributor Author

Done!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also add a note in the loss_type docstring that max_seq_len is required for seq_man_token_sum_norm now?

https://github.com/tamoghnokandar/SkyRL/blob/9c44257041101fe52893ad554bff026d49527b58/skyrl/train/config/config.py#L351

devin-ai-integration[bot]

This comment was marked as resolved.

Copy link
Copy Markdown
Contributor

@devin-ai-integration devin-ai-integration bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Devin Review found 1 new potential issue.

View 6 additional findings in Devin Review.

Open in Devin Review

Comment on lines 719 to 721
if self.trainer.algorithm.max_seq_len is None:
# NOTE (erictang000): this is the max sequence length including the prompt, since max response length
# per batch can be variable based on the prompt length. This is used to normalize the loss for
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Auto-calculation of max_seq_len in __post_init__ not removed, making the new validation dead code

The PR's stated intent is to require users to explicitly set max_seq_len when using seq_mean_token_sum_norm loss reduction. The validation in skyrl/train/utils/utils.py:286-293 checks if cfg.trainer.algorithm.max_seq_len is None and raises ValueError. However, SkyRLTrainConfig.__post_init__ at skyrl/train/config/config.py:719-728 still auto-calculates max_seq_len whenever it's None:

if self.trainer.algorithm.max_seq_len is None:
    self.trainer.algorithm.max_seq_len = (
        self.generator.max_input_length + self.generator.sampling_params.max_generate_length
    )

Since __post_init__ runs at config construction time (before validate_cfg is called), max_seq_len will never be None when the validation runs. This means: (1) the validate_cfg check is dead code in normal usage, (2) users who don't set max_seq_len will silently get auto-calculated values instead of the intended error, and (3) the test test_max_seq_len_defaults_to_none_when_not_set will fail because __post_init__ populates the value. The auto-calculation block should be removed from __post_init__.

(Refers to lines 719-728)

Prompt for agents
The auto-calculation of max_seq_len in SkyRLTrainConfig.__post_init__ (config.py lines 719-728) must be removed to align with the PR's intent. The PR adds a validation check in validate_cfg (utils.py lines 286-293) that raises ValueError when max_seq_len is None and loss_reduction is seq_mean_token_sum_norm. But the __post_init__ always fills in max_seq_len before validate_cfg is ever called, so the validation never triggers.

To fix: remove the entire if-block at lines 719-728 in skyrl/train/config/config.py. This will allow max_seq_len to remain None when not explicitly set, and the validate_cfg check will correctly fire when a user uses seq_mean_token_sum_norm without setting max_seq_len.

Note: after removing the auto-calculation, code that passes max_seq_len to preprocess_data (trainer.py:631) and apply_loss_reduction_to_advantages_minibatch (trainer.py:1084) may receive None. The preprocess_data function already handles Optional[int] for max_seq_len (it just logs a warning). For apply_loss_reduction_to_advantages_minibatch, max_seq_len is only used in the seq_mean_token_sum_norm branch, which will be guarded by the new validation. But you should verify no other code paths rely on max_seq_len being non-None for non-seq_mean_token_sum_norm reductions.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Copy link
Copy Markdown
Contributor Author

@tamoghnokandar tamoghnokandar Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SumanthRH Should I remove this? For seq_mean_token_sum_norm loss we are setting it explicitly anyways. This is not required for other loss functions right?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[bug][train] max_seq_len auto-calculation is incorrect for multi-turn in seq_mean_token_sum_norm

2 participants