[DO NOT REVIEW] [Pallas] Add TPU nightly benchmark workflow and runner#1913
Draft
Add a nightly CI workflow that runs Helion examples with autotuning on TPU, with results published to pytorch benchmark hub.
- gnupg -> gpg (package available on runner)
- Use correct secret name (torchtpu-read-key) and repo (google-pytorch/torch_tpu)
- Update jax/jaxlib to 0.9.2, matching test.yml
- Add 600s per-kernel timeout using multiprocessing to handle stuck autotuning (native C++ calls can't be interrupted by Python signals)
- Set HELION_AUTOTUNE_EFFORT=quick in CI for faster autotuning (30 initial population, 5 generations vs 100/20 for full)
- Make the timeout configurable via the HELION_BENCHMARK_KERNEL_TIMEOUT env var
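The multiprocessing-based timeout described above can be sketched as follows; the helper name `run_with_timeout` and its return convention are assumptions for illustration, and only the 600s default and the `HELION_BENCHMARK_KERNEL_TIMEOUT` env var come from the commit:

```python
import multiprocessing
import os

# Default of 600 s, overridable via the env var named in the commit message.
KERNEL_TIMEOUT_S = int(os.environ.get("HELION_BENCHMARK_KERNEL_TIMEOUT", "600"))

def run_with_timeout(fn, args=(), timeout=KERNEL_TIMEOUT_S):
    """Run fn in a child process and kill it if it exceeds the timeout.

    A separate process is used because autotuning stuck inside native C++
    calls cannot be interrupted by Python-level signals in-process.
    Returns the child's exit code, or None on timeout.
    """
    proc = multiprocessing.Process(target=fn, args=args)
    proc.start()
    proc.join(timeout)
    if proc.is_alive():           # still running after the deadline
        proc.terminate()
        proc.join()
        return None
    return proc.exitcode
```

A caller would treat `None` as a stuck kernel and move on to the next one rather than failing the whole run.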
The pytorch/test-infra gather-* actions require pip and nvidia-ml-py, which don't work in a uv venv on TPU runners. Remove the upload job and gather-* steps; keep only the artifact upload for now.
… generations
- Add --num-shapes CLI flag to control how many shapes per kernel (default: all)
- Restore full shape lists but use --num-shapes 1 in CI to avoid multiplied autotuning time
- Increase per-kernel timeout from 600s to 1200s (quick autotuning on v7 takes ~10min)
- Set HELION_AUTOTUNE_MAX_GENERATIONS=2 to further limit autotuning time
- Don't fail the job on partial kernel failures (report results for what passed)
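The `--num-shapes` flag could be wired up along these lines (an argparse sketch; `select_shapes` and the shape tuples are hypothetical, not the runner's actual code):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical mirror of the flag described above: limit shapes per kernel.
parser.add_argument(
    "--num-shapes", type=int, default=None,
    help="How many shapes to benchmark per kernel (default: all)",
)

def select_shapes(all_shapes, num_shapes):
    # None means run every shape; otherwise take the first N.
    return all_shapes if num_shapes is None else all_shapes[:num_shapes]

args = parser.parse_args(["--num-shapes", "1"])
shapes = select_shapes([(8, 2048), (16, 4096), (32, 8192)], args.num_shapes)
print(shapes)  # -> [(8, 2048)]
```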
… timeout

The benchmark runner was using multiprocessing.Process (fork) for per-kernel timeouts. On Linux, forking after TPU/JAX initialization causes deadlocks because JAX's internal threads and locks don't survive fork correctly. This caused every kernel to hang for the full timeout (1200s) on CI. Replace with signal.SIGALRM, which runs everything in one process, avoiding the fork-after-init issue entirely.
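A sketch of the SIGALRM-based replacement, assuming a hypothetical `run_with_alarm` helper (SIGALRM is POSIX-only and must be installed from the main thread):

```python
import signal

class KernelTimeout(Exception):
    """Raised when a kernel exceeds its time budget."""

def _on_alarm(signum, frame):
    raise KernelTimeout

def run_with_alarm(fn, timeout_s):
    """Run fn in the current process, aborting via SIGALRM after timeout_s.

    Staying in-process avoids the fork-after-TPU-init deadlock, at the cost
    of not being able to interrupt code that never returns control to the
    Python interpreter (handlers only run between bytecode instructions).
    """
    old_handler = signal.signal(signal.SIGALRM, _on_alarm)
    signal.alarm(timeout_s)       # schedule SIGALRM in timeout_s seconds
    try:
        return fn()
    finally:
        signal.alarm(0)           # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)
```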
Caution

The `push` trigger in `benchmark_tpu_nightly.yml` is temporary for CI testing; remove before merging.

Summary
Add a nightly CI workflow that runs Helion examples with autotuning on TPU, with results published to pytorch benchmark hub. This mirrors the GPU benchmark infrastructure (`benchmark_nightly.yml` + `benchmarks/run.py`) but uses a standalone runner, since TritonBench has hard dependencies on `triton` and CUDA at the import, install, and runtime levels.

New files
- `benchmarks/run_tpu.py`: TPU benchmark runner. Supports `--kernel`/`--op` (comma-separated), `--output` (JSON), and `--list-kernels`. Default kernels: `exp`, `add`, `softmax_two_pass`, `welford`, `layer_norm`.
- `.github/workflows/benchmark_tpu.yml`: reusable benchmark workflow. Runs on `linux.google.tpuv7x.1` with `HELION_ASSERT_CACHE_HIT=1` for the verify + record step.
- `.github/workflows/benchmark_tpu_nightly.yml`: nightly trigger, plus `workflow_dispatch` with a `kernels` input for manual runs.

Modified files
- `helion/_testing.py`: `run_example()` return type changed from `None` to `dict[str, float]` (mapping implementation names to their benchmark times in ms), enabling `run_tpu.py` to capture and display per-shape speedup without duplicating timing logic.

Excluded kernels (with reasons)
- `rms_norm`: `InductorLoweringError` in `torch.mean` reduction codegen for `fori_loop`/`emit_pipeline`
- `geglu`/`swiglu`: autotuning takes >15min per kernel (large shape 8x2048x4096); many configs fail to compile
- `low_mem_dropout`: ~37% element accuracy mismatch on all configs except `block_sizes=[128]`

Example output
Test plan
- Ran `run_tpu.py` on a TPU pod with all 5 default kernels
- Verified the `run_example()` return value is correctly consumed
- Tested via the `push` trigger on this branch
- Remove the `push` trigger before merging
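The `dict[str, float]` now returned by `run_example()` (implementation name to time in ms, per the Modified files note) could be consumed roughly like this; the baseline name and timings are made up for illustration:

```python
def speedup_over_baseline(times_ms, baseline):
    """Compute each implementation's speedup relative to a baseline.

    times_ms mirrors the dict[str, float] described above:
    implementation name -> benchmark time in milliseconds.
    """
    base = times_ms[baseline]
    return {name: base / t for name, t in times_ms.items() if name != baseline}

# Hypothetical numbers: helion at 0.5 ms vs a 2.0 ms torch baseline -> 4.0x.
print(speedup_over_baseline({"torch": 2.0, "helion": 0.5}, baseline="torch"))
```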