Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions skyrl/train/step_wise_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
Prefix-aware merging of step-wise trajectory turns for training.

When step_wise_trajectories=True, each turn is initially a separate sample.
We merge consecutive turns into fewer samples only when the next turn's prompt
token IDs have the previous full sequence (prompt + response) as an exact prefix.
Otherwise we keep them as separate samples (token-id prefix match only).

Merged samples avoid prompt/response overlap: ``prompt_token_ids`` is always the
first turn's observation in the merge group. Everything after that observation
(including later-turn observation *deltas* and all action tokens) lives in
``response_ids``, with delta tokens masked out of the loss (see loss_masks).
"""

from dataclasses import dataclass
from typing import List, Optional, Tuple


def _is_prefix(sequence: List[int], candidate: List[int]) -> bool:
"""Check if sequence is a prefix of candidate (exact token-id match)."""
if len(sequence) > len(candidate):
return False
return sequence == candidate[: len(sequence)]


@dataclass
class MergedStepWiseSample:
    """A single training sample after merging one or more step-wise turns."""

    # First turn's observation tokens of the merge group; never overlaps response_ids.
    prompt_token_ids: List[int]
    # resp1 + delta_ob2 + resp2 + ...: action tokens interleaved with later-turn
    # observation deltas (so prompt + response reconstructs the full sequence).
    response_ids: List[int]
    # Per-token rewards aligned with response_ids; 0.0 on observation-delta tokens.
    rewards: List[float]
    # Per-token loss mask aligned with response_ids; 0 on observation-delta tokens.
    loss_masks: List[int]
    # Optional per-token rollout logprobs aligned with response_ids (0.0 filler on deltas).
    rollout_logprobs: Optional[List[float]] = None
    # True iff the last turn absorbed into this sample was the trajectory's final turn.
    is_last_step: bool = False


def merge_step_wise_turns_for_trajectory(
    prompt_token_ids: List[List[int]],
    response_ids: List[List[int]],
    rewards: List[List[float]],
    loss_masks: List[List[int]],
    is_last_step: List[bool],
    rollout_logprobs: Optional[List[List[float]]] = None,
) -> Tuple[List[MergedStepWiseSample], int]:
    """
    Merge consecutive turns of a single trajectory whenever the next observation
    has the previously accumulated full sequence (prompt + response) as an exact
    token-id prefix.

    No data leakage: the merged prompt is only the first turn's observation; the
    merged response is resp1 + delta_ob2 + resp2 + ... where every delta_ob token
    carries a zero loss mask (plus 0.0 reward/logprob filler).

    Args:
        prompt_token_ids: Per-turn prompt (observation) token IDs.
        response_ids: Per-turn response (action) token IDs.
        rewards: Per-turn per-token rewards (list of lists).
        loss_masks: Per-turn loss masks (list of lists).
        is_last_step: Per-turn flag True only on the final turn of the trajectory.
        rollout_logprobs: Optional per-turn rollout logprobs (list of lists).

    Returns:
        (merged_samples, prefix_mismatch_count)
        - merged_samples: List of merged training samples for this trajectory.
        - prefix_mismatch_count: Number of times a merge was refused because the
          next observation did not extend the accumulated sequence.
    """
    num_turns = len(prompt_token_ids)
    assert num_turns == len(response_ids) == len(rewards) == len(loss_masks) == len(is_last_step)
    track_logprobs = rollout_logprobs is not None
    if track_logprobs:
        assert len(rollout_logprobs) == num_turns

    samples: List[MergedStepWiseSample] = []
    mismatches = 0

    turn = 0
    while turn < num_turns:
        # Accumulator state for the merge group starting at `turn`.
        seen: List[int] = []  # obs + response tokens so far, used only for the prefix check
        group_prompt: List[int] = []
        group_response: List[int] = []
        group_rewards: List[float] = []
        group_masks: List[int] = []
        group_logprobs: List[float] = []
        group_is_last = False

        cursor = turn
        while cursor < num_turns:
            observation = prompt_token_ids[cursor]
            if not seen:
                # No tokens accumulated yet: this observation becomes the prompt.
                delta = observation
                group_prompt = list(observation)
            elif _is_prefix(seen, observation):
                # Only the unseen suffix enters the response stream, masked out of
                # the loss so that prompt + response equals the full sequence.
                delta = observation[len(seen):]
                group_response.extend(delta)
                group_rewards.extend(0.0 for _ in delta)
                group_masks.extend(0 for _ in delta)
                if track_logprobs:
                    group_logprobs.extend(0.0 for _ in delta)
            else:
                # Observation does not extend the running sequence: close this
                # group and restart a fresh one at `cursor`.
                mismatches += 1
                break

            action = response_ids[cursor]
            seen.extend(delta)
            seen.extend(action)
            group_response.extend(action)
            group_rewards.extend(rewards[cursor])
            group_masks.extend(loss_masks[cursor])
            if track_logprobs:
                group_logprobs.extend(rollout_logprobs[cursor])
            # The last turn absorbed into the group decides the flag (plain
            # assignment, not OR, so a stale True never leaks forward).
            group_is_last = is_last_step[cursor]
            cursor += 1

        # Emit only groups that actually carry tokens.
        if group_prompt or group_response:
            samples.append(
                MergedStepWiseSample(
                    prompt_token_ids=group_prompt,
                    response_ids=group_response,
                    rewards=group_rewards,
                    loss_masks=group_masks,
                    rollout_logprobs=group_logprobs if track_logprobs else None,
                    is_last_step=group_is_last,
                )
            )
        # A mismatch can only occur after at least one turn was consumed
        # (seen is empty on the group's first turn), so cursor > turn here.
        turn = cursor

    return samples, mismatches
98 changes: 95 additions & 3 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
from collections import defaultdict
from dataclasses import asdict
from itertools import groupby
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -52,6 +53,10 @@
GeneratorInput,
GeneratorInterface,
GeneratorOutput,
TrajectoryID,
)
from skyrl.train.step_wise_merge import (
merge_step_wise_turns_for_trajectory,
)
from skyrl.train.generators.utils import (
get_metrics_from_generator_output,
Expand Down Expand Up @@ -610,6 +615,83 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
"rollout_expert_indices", None
)

num_samples_before_merge = len(prompt_ids)

if self.cfg.generator.step_wise_trajectories:
assert "trajectory_ids" in generator_output and "is_last_step" in generator_output
trajectory_ids_raw: List[TrajectoryID] = generator_output["trajectory_ids"]
is_last_step_list: List[bool] = generator_output["is_last_step"]

# Group consecutive indices by same trajectory (instance_id + repetition_id).
# groupby merges only adjacent runs with the same key; trajectory_ids_raw must list
# all turns of a trajectory in one contiguous block (no interleaving trajectories).
groups: List[Tuple[TrajectoryID, List[int]]] = []
for _, group in groupby(enumerate(trajectory_ids_raw), key=lambda x: x[1].to_string()):
indices = [i for i, _ in group]
if indices:
groups.append((trajectory_ids_raw[indices[0]], indices))

merged_prompt_ids: List[List[int]] = []
merged_response_ids: List[List[int]] = []
merged_rewards: List[List[float]] = []
merged_loss_masks: List[List[int]] = []
merged_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None
merged_is_last_step: List[bool] = []
merged_trajectory_ids: List[TrajectoryID] = []
total_prefix_mismatch = 0

for traj_id, indices in groups:
turn_prompts = [prompt_ids[j] for j in indices]
turn_responses = [response_ids[j] for j in indices]
turn_rewards = [rewards[j] for j in indices]
turn_masks = [loss_masks[j] for j in indices]
turn_is_last = [is_last_step_list[j] for j in indices]
turn_logprobs = [logprobs[j] for j in indices] if logprobs is not None else None

samples, mismatch_count = merge_step_wise_turns_for_trajectory(
prompt_token_ids=turn_prompts,
response_ids=turn_responses,
rewards=turn_rewards,
loss_masks=turn_masks,
is_last_step=turn_is_last,
rollout_logprobs=turn_logprobs,
)
total_prefix_mismatch += mismatch_count
for s in samples:
merged_prompt_ids.append(s.prompt_token_ids)
merged_response_ids.append(s.response_ids)
merged_rewards.append(s.rewards)
merged_loss_masks.append(s.loss_masks)
if merged_logprobs is not None:
merged_logprobs.append(s.rollout_logprobs)
merged_is_last_step.append(s.is_last_step)
merged_trajectory_ids.append(traj_id)

num_samples_after_merge = len(merged_prompt_ids)
prompt_ids = merged_prompt_ids
response_ids = merged_response_ids
rewards = merged_rewards
loss_masks = merged_loss_masks
logprobs = merged_logprobs
generator_output = {
**generator_output,
"prompt_token_ids": prompt_ids,
"response_ids": response_ids,
"rewards": rewards,
"loss_masks": loss_masks,
"rollout_logprobs": logprobs,
"is_last_step": merged_is_last_step,
"trajectory_ids": merged_trajectory_ids,
}
uids = [tid.instance_id for tid in merged_trajectory_ids]

self.all_metrics["trainer/stepwise_num_samples_before"] = num_samples_before_merge
self.all_metrics["trainer/stepwise_num_samples_after"] = num_samples_after_merge
self.all_metrics["trainer/stepwise_merge_ratio"] = (
num_samples_after_merge / num_samples_before_merge if num_samples_before_merge else 0.0
)
self.all_metrics["trainer/stepwise_prefix_mismatch_count"] = total_prefix_mismatch

(
sequences_tensor,
attention_masks_tensor,
Expand Down Expand Up @@ -673,9 +755,19 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
training_input.metadata["trajectory_ids"] = [
trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]
]
training_input.metadata["avg_response_length"] = sum(
len(sample_response_ids) for sample_response_ids in response_ids
) / len(response_ids)
last_step_response_lens = [
len(sample_response_ids)
for sample_response_ids, is_last in zip(response_ids, generator_output["is_last_step"])
if is_last
]
num_last_steps = len(last_step_response_lens)
training_input.metadata["avg_response_length"] = (
sum(last_step_response_lens) / num_last_steps if num_last_steps else 0.0
)
else:
training_input.metadata["avg_response_length"] = sum(
len(sample_response_ids) for sample_response_ids in response_ids
) / len(response_ids)

logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
training_input = self.pad_batch(training_input)
Expand Down
Loading
Loading