Replies: 2 comments
Thanks for driving this RFC. I'm strongly aligned with the goal of broadening accessibility beyond TPU. Some specific points (sharing my own views only):
Thanks again for spearheading this effort.
Thank you @yhtang for the valuable input, and +1 on everything! For bullet point 2, I have a separate RFC, because McJax can be device agnostic and both TPU and GPU can benefit from it. Let's also keep in mind that, with OSS Pathways on track, single-controller JAX and McJax each provide unique advantages in different areas; let's make sure Tunix achieves the best performance on both setups. For bullet point 3, I'm thinking of having an abstract attention adapter layer that provides a unified interface to the model, with various implementations or imports behind it. For bullet point 4, absolutely! Blackwell > Hopper > Ampere is the way to go.
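To illustrate the abstract attention adapter idea mentioned above, here is a minimal, hypothetical sketch; the class and method names are illustrative only and not part of the current Tunix API. A Flash-Attention or vendor-kernel backend would implement the same interface.

```python
# Hypothetical sketch of an abstract attention adapter layer; names are
# illustrative only, not the actual Tunix interface.
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class AttentionAdapter(ABC):
    """Unified attention interface the model calls; backends plug in behind it."""

    @abstractmethod
    def __call__(self, q, k, v, mask=None):
        ...


class XlaAttentionAdapter(AttentionAdapter):
    """Reference backend using plain XLA ops (runs on both TPU and GPU)."""

    def __call__(self, q, k, v, mask=None):
        # q, k, v: [..., seq, heads, head_dim]
        scores = jnp.einsum("...qhd,...khd->...hqk", q, k) / jnp.sqrt(q.shape[-1])
        if mask is not None:
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
        probs = jax.nn.softmax(scores, axis=-1)
        return jnp.einsum("...hqk,...khd->...qhd", probs, v)
```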
Status: Draft for community feedback
Authors: Lance Wang
Tunix focuses on TPU-first training/inference for SFT and RL (DPO/PPO/GRPO/GSPO/...). Community users and contributors have expressed strong interest in running Tunix on GPUs (A100/H100, 4090/4080, L4, etc.). GPU support broadens accessibility, enables on‑prem/smaller‑scale runs, and increases contributor velocity.
Goals
Modular design: a clear boundary between GPU-specific and TPU-specific code.
Parity path: Run the full Tunix SFT + RL stack on single- and multi‑GPU via XLA:GPU (PJRT), progressing from unit tests to the notebook/script examples (https://github.com/google/tunix/tree/main/examples); see the smoke-test sketch after this list.
Performance-minded: Competitive throughput/latency using bf16/fp16, Flash‑Attention, and fused optimizers.
Simple install: Docker and Conda instructions; reproducible envs for CUDA 12.x/13.x + cuDNN 9 or ROCm.
CI coverage: Smoke tests on single‑GPU; nightly correctness on multi‑GPU via self-hosted runners.
Docs: A clear “Getting Started on GPU” guide and troubleshooting tips.
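A minimal sketch (not Tunix code) of the kind of single‑GPU smoke test the parity goal implies: confirm that JAX sees the GPU through PJRT and that a bf16 op compiles under XLA:GPU.

```python
# Minimal single-GPU smoke test (illustrative; assumes jax[cuda] is installed).
import jax
import jax.numpy as jnp

# PJRT should report the GPU backend, e.g. on an A100/H100/4090.
assert jax.default_backend() == "gpu", f"Expected XLA:GPU, got {jax.default_backend()}"
print("Devices:", jax.devices())

# bf16 matmul compiled through XLA:GPU.
x = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
y = jax.jit(lambda a: a @ a.T)(x)
print("bf16 matmul OK:", y.dtype, y.shape)
```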
Non‑Goals
Non‑JAX training backends (e.g., PyTorch) for core Tunix trainers.
Perfect performance parity with TPU on day one.
US1 (Single‑GPU dev): A researcher with a 4090 runs SFT and minimal GRPO locally.
US2 (Multi‑GPU node): A lab with an 8xH100 node trains 1B–30B models using FSDP + TP (see the mesh sketch after these user stories).
US3 (Multi‑host): A cluster with multiple H100 nodes runs distributed RL with NCCL over InfiniBand/RoCE.
US4 (Eval/Serve): Run RL learning and eval loops on GPU without TPU dependencies.
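For US2, a sketch of what an FSDP + TP layout on a single 8‑GPU node might look like with jax.sharding; the axis names, mesh shape, and weight shape are assumptions for illustration, not Tunix configuration.

```python
# Illustrative FSDP + TP mesh for one 8-GPU node; assumes 8 visible devices.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((4, 2))   # 4-way FSDP x 2-way TP
mesh = Mesh(devices, axis_names=("fsdp", "tp"))

# Shard a hypothetical weight matrix: rows over FSDP, columns over TP.
w = jnp.zeros((8192, 8192), dtype=jnp.bfloat16)
w = jax.device_put(w, NamedSharding(mesh, P("fsdp", "tp")))
print(w.sharding)
```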
R1: Support CUDA 12/13; cuDNN 9.x.
R2: Support bf16 on Ampere/Hopper; fp16 fallback on consumer GPUs if needed.
R3: PJRT runtime on XLA:GPU; NCCL collectives for dp/fsdp/tp (see the multi-host sketch after this list).
R4: Flash‑Attention path (jax-labs/pallas or vendor kernels) for attention-heavy models.
R5: CI smoke tests on single GPU; perf sanity benchmarks; correctness parity on small models.
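For R3, a hedged sketch of multi-host initialization on XLA:GPU; JAX routes cross-device collectives through NCCL, and the coordinator address and process counts below are placeholders to be set per cluster.

```python
# Multi-host setup sketch (run one process per host; values are placeholders).
import jax
import jax.numpy as jnp

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # placeholder: first host's address
    num_processes=2,                      # placeholder: number of hosts
    process_id=0,                         # placeholder: unique per host
)

# Data-parallel all-reduce across every GPU on every host (runs over NCCL).
x = jnp.ones((jax.local_device_count(),))
total = jax.pmap(lambda v: jax.lax.psum(v, axis_name="dp"), axis_name="dp")(x)
print(total)  # each entry equals the global GPU count
```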