Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}]
info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
{'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'},
{'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'},
{'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8},
{'test_file': 'test_qwen3_4B_ppo_train_critic_only.py', 'num_gpus': 8},
{'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8},
{'test_file': 'test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'},
{'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8},
Expand Down
8 changes: 4 additions & 4 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ def init(
args, role
)

start_rollout_id = loaded_rollout_id + 1

if role == "critic":
if self.args.offload_train:
self.sleep()
return

start_rollout_id = loaded_rollout_id + 1
return start_rollout_id

self.weights_backuper = TensorBackuper.create(
source_getter=lambda: named_params_and_buffers(
Expand Down Expand Up @@ -378,7 +378,7 @@ def train_critic(self, rollout_id: int, rollout_data: RolloutBatch) -> None:
)
)

if rollout_id >= self.args.num_critic_only_steps:
if rollout_id >= self.args.num_critic_only_steps and not self.args.critic_train_only:
sync_actor_critic_data(self.args, rollout_data, self._actor_critic_groups)

compute_advantages_and_returns(self.args, rollout_data)
Expand Down
13 changes: 10 additions & 3 deletions slime/ray/placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,22 @@ def create_training_models(args, pgs, rollout_manager):
)
)

if args.use_critic:
critic_start_rollout_ids = ray.get(critic_init_handle)
if not args.critic_train_only:
actor_model.connect(critic_model)
else:
start_rollout_ids = critic_start_rollout_ids

assert len(set(start_rollout_ids)) == 1

if args.start_rollout_id is None:
args.start_rollout_id = start_rollout_ids[0]

actor_model.set_rollout_manager(rollout_manager)
if args.use_critic:
ray.get(critic_init_handle)
actor_model.connect(critic_model)
critic_model.set_rollout_manager(rollout_manager)
Comment on lines 171 to +173
Copy link

Copilot AI Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When --critic-train-only is set but --use-critic is not, critic_model is None and this will raise an AttributeError. Either validate that critic_train_only implies use_critic, or guard this block with if args.use_critic (and provide a clear error message if not).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can always run

critic_model.set_rollout_manager(rollout_manager)

maybe change if args.critic_train_only to if use_critic and remove the critic.set_rollout_manager in actor_model.set_rollout_manager


actor_model.set_rollout_manager(rollout_manager)
if args.rollout_global_dataset:
ray.get(rollout_manager.load.remote(args.start_rollout_id - 1))

Expand Down
1 change: 1 addition & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ def add_algo_arguments(parser):
parser.add_argument("--critic-load", type=str, default=None, help="The checkpoint for critic model.")
parser.add_argument("--critic-save", type=str, default=None, help="The checkpoint for critic model.")
parser.add_argument("--critic-lr", type=float, default=None, help="The lr for critic model")
parser.add_argument("--critic-train-only", action="store_true", default=False, help="Only train critic")
parser.add_argument(
"--critic-lr-warmup-iters",
type=int,
Expand Down
134 changes: 134 additions & 0 deletions tests/test_qwen3_4B_ppo_train_critic_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import os

import slime.utils.external_utils.command_utils as U


# Toggle periodic evaluation during training; used below to decide whether
# to pass --eval-interval (env override, e.g. for CI).
ENABLE_EVAL = bool(int(os.environ.get("SLIME_TEST_ENABLE_EVAL", "1")))
# When set, shrink memory-heavy settings in execute(): a smaller
# --max-tokens-per-gpu and no --use-kl-loss.
TIGHT_HOST_MEMORY = bool(int(os.environ.get("SLIME_TEST_TIGHT_HOST_MEMORY", "1")))

MODEL_NAME = "Qwen3-4B"  # HF repo name / local directory under /root/models
MODEL_TYPE = "qwen3-4B"  # Megatron model-type identifier for checkpoint conversion
NUM_GPUS = 8  # GPUs per node for this test


def prepare():
    """Fetch the model weights and test datasets, then convert the HF
    checkpoint into Megatron torch-dist format."""
    U.exec_command("mkdir -p /root/models /root/datasets")
    U.exec_command("hf download Qwen/Qwen3-4B --local-dir /root/models/Qwen3-4B")

    # Pull both datasets used by the rollout and eval phases.
    for dataset_repo in ("zhuzilin/dapo-math-17k", "zhuzilin/aime-2024"):
        U.hf_download_dataset(dataset_repo)

    U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS)


def execute():
    """Assemble the CLI argument string for the critic-only PPO training run
    and launch it via the shared command helper."""

    def _group(flags):
        # Join one flag group; every group carries a trailing space, matching
        # the layout the final train_args string is built from.
        return " ".join(flags) + " "

    ckpt_args = _group(
        [
            f"--hf-checkpoint /root/models/{MODEL_NAME}/",
            f"--ref-load /root/{MODEL_NAME}_torch_dist",
        ]
    )

    rollout_args = _group(
        [
            "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl",
            "--input-key prompt",
            "--label-key label",
            "--apply-chat-template",
            "--rollout-shuffle",
            "--rm-type deepscaler",
            "--num-rollout 3",
            "--rollout-batch-size 8",
            "--n-samples-per-prompt 8",
            "--rollout-max-response-len 8192",
            "--rollout-temperature 0.8",
            "--global-batch-size 32",
            "--balance-data",
        ]
    )

    eval_flags = []
    if ENABLE_EVAL:
        eval_flags.append("--eval-interval 20")
    eval_flags += [
        "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl",
        "--n-samples-per-eval-prompt 1",
        "--eval-max-response-len 16384",
        "--eval-top-k 1",
    ]
    eval_args = _group(eval_flags)

    perf_args = _group(
        [
            "--tensor-model-parallel-size 2",
            "--sequence-parallel",
            "--pipeline-model-parallel-size 1",
            "--context-parallel-size 2",
            "--recompute-granularity full",
            "--recompute-method uniform",
            "--recompute-num-layers 1",
            "--use-dynamic-batch-size",
            # Keep the per-GPU token budget small on memory-constrained hosts.
            f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 16384}",
        ]
    )

    ppo_flags = ["--advantage-estimator ppo"]
    if not TIGHT_HOST_MEMORY:
        # KL loss keeps an extra reference forward pass; skip it when memory is tight.
        ppo_flags.append("--use-kl-loss")
    ppo_flags += [
        "--kl-loss-coef 0.00",
        "--kl-loss-type k1",
        "--kl-coef 0.00",
        "--entropy-coef 0.00",
        "--eps-clip 4e-4",
        "--critic-train-only",
        "--normalize-advantages",
        "--critic-lr 1e-5",
    ]
    ppo_args = _group(ppo_flags)

    optimizer_args = _group(
        [
            "--optimizer adam",
            "--lr 1e-6",
            "--lr-decay-style constant",
            "--weight-decay 0.1",
            "--adam-beta1 0.9",
            "--adam-beta2 0.98",
        ]
    )

    sglang_args = _group(
        [
            "--rollout-num-gpus-per-engine 2",
            "--rollout-num-gpus 4",
            "--sglang-mem-fraction-static 0.8",
            "--sglang-max-running-requests 512",
            "--sglang-enable-metrics",
        ]
    )

    ci_args = "--ci-test "

    misc_args = _group(
        [
            # Megatron defaults dropout to 0.1; disable it for this run.
            "--attention-dropout 0.0",
            "--hidden-dropout 0.0",
            # fp32 grad accumulation / softmax: should be good for model performance.
            "--accumulate-allreduce-grads-in-fp32",
            "--attention-softmax-in-fp32",
            # NOTE(review): this backend flag must be dropped for MLA models.
            "--attention-backend flash",
            # Critic-only run: no actor GPUs, one critic node with 4 GPUs.
            "--actor-num-nodes 0",
            "--actor-num-gpus-per-node 0",
            "--critic-num-nodes 1",
            "--critic-num-gpus-per-node 4",
        ]
    )

    # Group order matters only for readability of the final command; each
    # group already ends in a space and one more is inserted between groups,
    # reproducing the original double-space layout.
    group_sequence = (
        ckpt_args,
        rollout_args,
        optimizer_args,
        ppo_args,
        U.get_default_wandb_args(__file__),
        perf_args,
        eval_args,
        sglang_args,
        ci_args,
        misc_args,
    )
    train_args = "".join(f"{part} " for part in group_sequence)

    U.execute_train(
        train_args=train_args,
        num_gpus_per_node=NUM_GPUS,
        megatron_model_type=MODEL_TYPE,
    )


if __name__ == "__main__":
# TODO also use typer
prepare()
for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
os.environ.pop(proxy_var, None)
execute()
Loading