Skip to content

Conversation

@huy209vn
Copy link
Contributor

@huy209vn huy209vn commented Oct 21, 2025

Implements RMS Normalization with support for bias, configurable epsilon, and runtime vectorization.
Includes fully verified CUDA benchmark and backend tests across CPU, WGPU, HIP.

Status

✅ Kernel: finalized, verified across dtypes (F32, F16, BF16)
✅ Benchmarks: refactored, stable, true GPU-time measurement
✅ Tests: cross-backend validation vs CPU reference

Kernel Notes

  • Two variants:

    • streaming — minimal register pressure (default)

    • smem — kept for future optimization, currently disabled

  • Vectorization: vec=8 (F32), vec=4 (F16/BF16)

  • Enforces ≥512 threads for occupancy

  • Avoids register spills via MAX_LINES_PER_THREAD=4

  • Uses fast reciprocal sqrt when supported

Benchmark Results (RTX 3090, CUDA)

Shape: [8, 1024, 4096] (33.55M elements)

Type Time (ms) Throughput (Gelem/s)
F32 0.419 80.04
F16 0.219 153.54
BF16 0.219 153.41

Shape: [16, 1024, 4096] (67.11M elements)

Type Time (ms) Throughput (Gelem/s)
F32 0.838 80.05
F16 0.430 155.97
BF16 0.425 157.85

Shape: [32, 2048, 4096] (268.44M elements)

Type Time (ms) Throughput (Gelem/s)
F32 3.402 78.90
F16 1.688 159.05
BF16 1.658 161.94

Shape: [64, 4096, 4096] (1073.74M elements)

Type Time (ms) Throughput (Gelem/s)
F32 13.309 80.68
F16 6.438 166.78
BF16 6.479 165.73

✅ No regressions.
✅ All dtypes saturate memory bandwidth as expected.

Commands

cargo bench -p cubecl --bench rms_norm --features "cuda,random" -- --nocapture

Env Overrides

CUBECL_RMS_LOG=1 # verbose logging CUBECL_RMS_VEC=4|8|16 # vectorization override CUBECL_RMS_VARIANT=stream # force variant

@huy209vn huy209vn marked this pull request as draft October 24, 2025 12:51
@huy209vn huy209vn marked this pull request as ready for review October 31, 2025 15:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant