Labels: type:bug (Something isn't working)
Description
Expected Behavior
GRPO training with Tunix in multi-process SPMD (e.g., 2 GPU nodes) should generate rollouts and compute advantages without failures.
Actual Behavior
Crash during rollout sampling when iterating a distributed jax.Array:
AssertionError at jax/_src/array.py:380: assert self.is_fully_replicated or self.is_fully_addressable
Triggered in tunix/generate/sampler.py while doing zip(out_tokens, lengths).
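For context, jax.Array iteration raises this assertion whenever the array is neither fully replicated nor fully addressable, which is exactly the case for a rollout buffer sharded across processes. Below is a minimal sketch (not Tunix source) of one possible workaround: gathering the array to every host with jax.experimental.multihost_utils.process_allgather before iterating. The names out_tokens and lengths mirror the sampler variables mentioned above; gather_for_host_iteration is a hypothetical helper.

```python
# Hedged sketch, not Tunix code: why zip(out_tokens, lengths) fails on a
# cross-process sharded jax.Array, and one possible workaround.
import jax
import numpy as np
from jax.experimental import multihost_utils


def gather_for_host_iteration(arr: jax.Array) -> np.ndarray:
    """Return a host-local NumPy copy that is safe to iterate on every process.

    jax.Array.__iter__ asserts `is_fully_replicated or is_fully_addressable`,
    so Python-level iteration (zip, list, for) over an array sharded across
    processes raises the AssertionError reported above.
    """
    if arr.is_fully_addressable or arr.is_fully_replicated:
        return np.asarray(jax.device_get(arr))
    # Gather shards from all processes; tiled=True concatenates along axis 0.
    return multihost_utils.process_allgather(arr, tiled=True)


# Assumed call site inside the rollout loop:
# for tokens, length in zip(gather_for_host_iteration(out_tokens),
#                           gather_for_host_iteration(lengths)):
#     ...
```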
Steps to Reproduce the Problem
- Add jax.distributed.initialize() at the beginning of the GRPO example script (see the sketch after this list).
- Run the GRPO demo:
srun -u --label --ntasks=2 --ntasks-per-node=1 -c${SLURM_CPUS_ON_NODE} python <grpo-script>
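A minimal sketch of the initialization step referenced above; under Slurm, jax.distributed.initialize() can usually auto-detect the coordinator address, process count, and process index from the environment, though explicit arguments may be needed on some clusters. The print line is only illustrative.

```python
# Sketch of the multi-process setup added to the GRPO example script
# (assumed placement at the very top, before any other JAX calls).
import jax

# Coordinator address, num_processes, and process_id are usually auto-detected
# under Slurm; they can also be passed explicitly if detection fails.
jax.distributed.initialize()

print(f"process {jax.process_index()}/{jax.process_count()}: "
      f"{jax.local_device_count()} local / {jax.device_count()} global devices")
```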
Checklist
- I have searched the existing issues for a similar bug report.
- I have provided all the required information in the "Environment" section.
- I have provided a minimal, reproducible example.
Would you like to help us fix it?