
GRPO training fails in multi-process mode #698

@yhtang

Description


Expected Behavior
GRPO training with Tunix in multi-process SPMD (e.g., 2 GPU nodes) should generate rollouts and compute advantages without failures.

Actual Behavior
Crash during rollout sampling when iterating a distributed jax.Array:

  • AssertionError at jax/_src/array.py:380: assert self.is_fully_replicated or self.is_fully_addressable
  • Triggered in tunix/generate/sampler.py while doing zip(out_tokens, lengths)
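The failing pattern and one possible workaround can be sketched as follows. This is a minimal single-process illustration, not Tunix code: out_tokens and lengths here are hypothetical stand-ins for the sampler's arrays, and the workaround (pulling data to host memory before Python-level iteration) is one common approach, not necessarily the fix the maintainers will choose.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-ins for the sampler's outputs.
out_tokens = jnp.arange(8).reshape(2, 4)   # [batch, seq_len]
lengths = jnp.array([4, 3])                # per-sample lengths

# Iterating a jax.Array asserts that it is fully replicated or fully
# addressable (jax/_src/array.py). Under multi-process SPMD, a sharded
# array whose shards live partly on other hosts fails that assertion,
# so zip(out_tokens, lengths) crashes in the sampler.
#
# Workaround sketch: fetch the data to host memory first, then iterate
# ordinary NumPy arrays.
out_tokens_host = np.asarray(jax.device_get(out_tokens))
lengths_host = np.asarray(jax.device_get(lengths))

pairs = list(zip(out_tokens_host, lengths_host))
```

In a real multi-process run, each process would only see its addressable shards this way; gathering a globally consistent copy would need something like jax.experimental.multihost_utils.process_allgather instead.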

Steps to Reproduce the Problem

  1. Add jax.distributed.initialize() at the beginning of the GRPO example script.
  2. Run the GRPO demo under SLURM with one task per node:
srun -u --label --ntasks=2 --ntasks-per-node=1 -c${SLURM_CPUS_ON_NODE} python <grpo-script>
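Step 1 can be sketched as a guarded initialization that keeps the script runnable both single-process and under srun. The SLURM_NTASKS check is an assumption for illustration, not part of the original script; when launched with multiple tasks, jax.distributed.initialize() can auto-detect the coordinator address, process count, and process id from the SLURM environment.

```python
import os
import jax


def maybe_init_distributed() -> None:
    # Sketch: only start JAX's distributed runtime when launched with
    # multiple SLURM tasks; a plain single-process run skips it.
    if int(os.environ.get("SLURM_NTASKS", "1")) > 1:
        jax.distributed.initialize()


maybe_init_distributed()
print(jax.process_count())  # 1 in a single-process run
```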

Checklist

  • I have searched the existing issues for a similar bug report.
  • I have provided all the required information in the "Environment" section.
  • I have provided a minimal, reproducible example.

Would you like to help us fix it?

Metadata

Labels: type:bug (Something isn't working)
