
[DO NOT REVIEW] [Pallas] Add TPU nightly benchmark workflow and runner#1913

Draft
norx1991 wants to merge 8 commits into main from yifeixu/tpu-nightly-benchmark

Conversation


@norx1991 norx1991 commented Apr 1, 2026

Caution

The push trigger in benchmark_tpu_nightly.yml is temporary for CI testing — remove before merging.

Summary

Add a nightly CI workflow that runs Helion examples with autotuning on TPU, with results published to the PyTorch benchmark hub. This mirrors the GPU benchmark infrastructure (benchmark_nightly.yml + benchmarks/run.py) but uses a standalone runner, since TritonBench has hard dependencies on Triton and CUDA at import, install, and run time.

New files

  • benchmarks/run_tpu.py — TPU benchmark runner

    • CLI: --kernel/--op (comma-separated), --output (JSON), --list-kernels
    • Multi-shape benchmarking: each kernel tested at multiple input sizes (like TritonBench does for GPU)
    • Per-shape accuracy check + timing vs torch baseline with speedup display
    • JSON output in pytorch benchmark hub format
    • 5 reliable kernels: exp, add, softmax_two_pass, welford, layer_norm
  • .github/workflows/benchmark_tpu.yml — Reusable benchmark workflow

    • Runner: linux.google.tpuv7x.1
    • Setup: PyTorch CPU nightly, JAX/Pallas, builds torch_tpu from pinned commit
    • Two-pass pattern: autotune → sleep 1min → HELION_ASSERT_CACHE_HIT=1 verify + record
    • Upload to pytorch benchmark hub
  • .github/workflows/benchmark_tpu_nightly.yml — Nightly trigger

    • Cron: daily at 2 AM PST (10 AM UTC)
    • workflow_dispatch with kernels input for manual runs
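The per-shape flow described above (accuracy check, timing vs the torch baseline, speedup display, hub-format JSON) can be sketched roughly as follows. This is a minimal sketch, not the actual run_tpu.py code: the `run_kernel` callback and the record schema are hypothetical stand-ins.

```python
import json


def benchmark_shapes(kernel_name, shapes, run_kernel, tolerance=1e-4):
    """Benchmark one kernel across several input shapes (hypothetical helper).

    `run_kernel(shape)` is assumed to return
    (max_abs_error, helion_ms, torch_ms) for that shape.
    """
    records = []
    for shape in shapes:
        max_err, helion_ms, torch_ms = run_kernel(shape)
        status = "PASS" if max_err <= tolerance else "FAIL"
        # <1.0 means Helion is slower than the torch baseline for this shape
        speedup = torch_ms / helion_ms
        records.append({
            "kernel": kernel_name,
            "shape": shape,
            "status": status,
            "helion_ms": helion_ms,
            "torch_ms": torch_ms,
            "speedup": round(speedup, 2),
        })
    return records


# Fake timings mirroring the exp [1024] row in the summary table below:
fake = lambda shape: (0.0, 0.3410, 0.2676)
rows = benchmark_shapes("exp", [[1024]], fake)
print(json.dumps(rows[0]))
```

The JSON records here are a simplified stand-in for the benchmark hub format; the real schema is whatever the hub's upload step expects.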

Modified files

  • helion/_testing.py — run_example() return type changed from None to dict[str, float] (mapping implementation names to their benchmark times in ms), enabling run_tpu.py to capture and display per-shape speedups without duplicating timing logic
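Consuming the new return value is straightforward; a minimal sketch (the "helion"/"torch" keys are assumptions for illustration — the actual keys depend on the implementations each example registers):

```python
def display_speedup(times_ms: dict[str, float]) -> str:
    """Format the per-shape timing dict returned by run_example().

    Speedup is baseline time over Helion time, so values below 1.0
    mean Helion is slower than the torch baseline.
    """
    helion_ms = times_ms["helion"]
    torch_ms = times_ms["torch"]
    return (f"Helion {helion_ms:.4f} ms vs Torch {torch_ms:.4f} ms "
            f"({torch_ms / helion_ms:.2f}x)")


# Timings from the exp [1024] row of the summary table:
print(display_speedup({"helion": 0.3410, "torch": 0.2676}))
```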

Excluded kernels (with reasons)

  • rms_norm: InductorLoweringError in torch.mean reduction codegen for fori_loop/emit_pipeline
  • geglu/swiglu: autotuning takes >15min per kernel (large shape 8x2048x4096), many configs fail to compile
  • low_mem_dropout: ~37% element accuracy mismatch on all configs except block_sizes=[128]

Example output

===========================================================================
Summary
===========================================================================
Kernel                 Shape            Status   Helion (ms)    Torch (ms)     Speedup
---------------------------------------------------------------------------
exp                    [1024]           PASS     0.3410         0.2676         0.78x
exp                    [4096]           PASS     0.2829         0.2320         0.82x
exp                    [16384]          PASS     0.2924         0.2276         0.78x
exp                    [65536]          PASS     0.2744         0.2173         0.79x
exp                    [262144]         PASS     0.3246         0.2355         0.73x
exp                    [1048576]        PASS     0.3076         0.2134         0.69x
add                    [128,128]        PASS     0.2688         0.2087         0.78x
add                    [256,256]        PASS     0.2727         0.2149         0.79x
add                    [512,512]        PASS     0.3202         0.2259         0.71x
add                    [1024,1024]      PASS     0.3165         0.2225         0.70x
add                    [2048,2048]      PASS     0.3604         0.2440         0.68x
softmax_two_pass       [1024,256]       PASS     0.3279         0.2375         0.72x
softmax_two_pass       [1024,512]       PASS     0.3701         0.2371         0.64x
softmax_two_pass       [1024,1024]      PASS     0.3610         0.2391         0.66x
softmax_two_pass       [1024,2048]      PASS     0.3112         0.2195         0.71x
softmax_two_pass       [1024,4096]      PASS     0.3373         0.2317         0.69x
---------------------------------------------------------------------------
Total: 3/3 passed
===========================================================================

Test plan

  • Tested run_tpu.py on TPU pod with all 5 default kernels
  • Verified per-shape accuracy and speedup display
  • Verified run_example() return value is correctly consumed
  • CI run triggered via temporary push trigger on this branch
  • Remove push trigger before merging

Add a nightly CI workflow that runs Helion examples with autotuning on TPU,
with results published to pytorch benchmark hub.
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 1, 2026
@norx1991 norx1991 changed the title Add TPU nightly benchmark workflow and runner [DO NOT REVIEW] [Pallas] Add TPU nightly benchmark workflow and runner Apr 1, 2026
norx1991 added 7 commits April 1, 2026 13:41
- gnupg -> gpg (package available on runner)
- Use correct secret name (torchtpu-read-key) and repo (google-pytorch/torch_tpu)
- Update jax/jaxlib to 0.9.2 matching test.yml
- Add 600s per-kernel timeout using multiprocessing to handle stuck
  autotuning (native C++ calls can't be interrupted by Python signals)
- Set HELION_AUTOTUNE_EFFORT=quick in CI for faster autotuning
  (30 initial population, 5 generations vs 100/20 for full)
- Timeout configurable via HELION_BENCHMARK_KERNEL_TIMEOUT env var
The pytorch/test-infra gather-* actions require pip and nvidia-ml-py,
which don't work in a uv venv on TPU runners. Remove the upload job
and gather-* steps; keep only the artifact upload for now.
… generations

- Add --num-shapes CLI flag to control how many shapes per kernel (default: all)
- Restore full shape lists but use --num-shapes 1 in CI to avoid multiplied autotuning time
- Increase per-kernel timeout from 600s to 1200s (quick autotuning on v7 takes ~10min)
- Set HELION_AUTOTUNE_MAX_GENERATIONS=2 to further limit autotuning time
- Don't fail the job on partial kernel failures (report results for what passed)
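Taken together, the CI-side autotuning knobs from these commits could look like the following hypothetical workflow excerpt (names mirror the commit messages above, not the actual workflow file):

```yaml
env:
  HELION_AUTOTUNE_EFFORT: quick            # 30 initial population, 5 generations (vs 100/20 for full)
  HELION_AUTOTUNE_MAX_GENERATIONS: "2"     # further cap autotuning time in CI
  HELION_BENCHMARK_KERNEL_TIMEOUT: "1200"  # per-kernel timeout in seconds
```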
… timeout

The benchmark runner was using multiprocessing.Process (fork) for per-kernel
timeouts. On Linux, forking after TPU/JAX initialization causes deadlocks
because JAX's internal threads and locks don't survive fork correctly. This
caused every kernel to hang for the full timeout (1200s) on CI.

Replace with signal.SIGALRM which runs everything in one process, avoiding
the fork-after-init issue entirely.
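The SIGALRM approach can be sketched as below. This is an illustrative sketch of the technique, not the runner's actual code; note the caveat from the earlier commit still applies — a Python signal handler only fires once control returns to the interpreter, so a stuck native call may not be interruptible, but at least the whole process no longer deadlocks on fork.

```python
import signal
import time


class KernelTimeout(Exception):
    """Raised when a kernel run exceeds its time budget."""


def run_with_timeout(fn, seconds):
    """Run fn() in the current process, raising KernelTimeout on overrun.

    SIGALRM avoids the fork-after-JAX-init deadlock, but only works on
    the main thread of a Unix process.
    """
    def _handler(signum, frame):
        raise KernelTimeout(f"kernel exceeded {seconds}s")

    old_handler = signal.signal(signal.SIGALRM, _handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)  # sub-second resolution
    try:
        return fn()
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)    # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)


print(run_with_timeout(lambda: "fast kernel ok", 1.0))  # completes normally
try:
    run_with_timeout(lambda: time.sleep(2), 0.2)        # interrupted early
except KernelTimeout as e:
    print(e)
```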