[Layernorm] Fix autotuner crash and OOB writes in layer_norm_bwd on high-SM GPUs#796
Conversation
Two fixes for layer_norm_bwd_kernel on Blackwell sm_120 (188 SMs):
1. Remove NB from autotuner keys (fwd + bwd kernels)
NB = cdiv(T, 2048) was included in the autotuner key to force
re-autotuning when the token count changed significantly. However,
NB is never referenced in the kernel body — it exists solely as a
cache key. On Blackwell sm_120 with 188 SMs, the Triton autotuner
crashes with "CUDA error: illegal memory access" when benchmarking
the HAS_DRESIDUAL=False kernel variant at NS=188 grid size. Each
new NB value triggered a fresh autotuner benchmark run, repeatedly
hitting this Triton codegen/runtime bug.
Removing NB from the key means the autotuner runs once per
(D, HAS_DRESIDUAL, STORE_DRESIDUAL, IS_RMS_NORM) combination and
reuses the cached config for all T values. The first autotuning
at a small T (safe grid size) produces a config that works for
all subsequent larger T values.
Reproduced with minimal test: RMSNorm(256) backward on T=24000
without residual crashes during autotuner benchmarking. Same
shape with residual (HAS_DRESIDUAL=True) succeeds.
2. Cap NS so BS >= max autotuned BT (64)
On high-SM GPUs with small T (e.g., 450 tokens on 188 SMs),
BS = cdiv(450, 188) = 3. The autotuner tries BT=64, so BS < BT.
Each program's make_block_ptr creates a block of (BT=64, BD)
rows, but the program only owns BS=3 rows. Adjacent programs'
blocks overlap on the output tensor dx, causing write races that
corrupt GPU memory. Capping NS so each program handles at least
_MAX_BT=64 tokens ensures non-overlapping write regions.
For T=450: NS reduced from 188 to 8, BS increased from 3 to 57.
Both fixes include the T_g/T_end fix from the prior commit for
correct make_block_ptr shapes and loop bounds when NS > T // G
(idle programs with no work get empty loop ranges).
Add regression tests:
- test_rmsnorm_varying_nb_no_residual: T in {100..24000}, D=256
Exercises HAS_DRESIDUAL=False across different NB values
- test_rmsnorm_varying_nb_with_residual: same T range
Exercises HAS_DRESIDUAL=True for completeness
- Verified: tests FAIL without fix (crash at T=24000), PASS with
WalkthroughRemoved Changes
Sequence Diagram(s)(omitted — changes are kernel and sizing logic without multi-component sequential interactions warranting a diagram) Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request addresses autotuner crashes on high-SM GPUs by removing 'NB' from the autotune key and ensures that the block size in the backward pass is at least the maximum autotuned block size to prevent memory corruption. A logic error was identified in the calculation of the maximum number of programs per group, where the use of 'triton.cdiv' instead of integer division could still result in overlapping writes.
triton.cdiv rounds up, which can result in BS < _MAX_BT. For example, T_per_group=65: cdiv(65, 64) = 2, BS = cdiv(65, 2) = 33 < 64. Floor division guarantees each program handles at least _MAX_BT tokens: T_per_group=65: 65 // 64 = 1, BS = cdiv(65, 1) = 65 >= 64.
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (2)
fla/modules/layernorm.py (1)
656-656: Avoid hard-coding_MAX_BTseparately from autotune configs.Please derive this cap from a shared constant (or from the BT config list) to prevent drift if autotune BT options change later.
Based on learnings: Align threshold constants used by check_shared_mem to a single source of truth to avoid semantic drift.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/modules/layernorm.py` at line 656, The hard-coded _MAX_BT constant should be derived from the autotune BT options instead of being duplicated; modify the code so _MAX_BT is computed (e.g., max_bt = max(BT_CONFIG_LIST) or len/ max of the BT values used by the autotuner) and used by check_shared_mem and any other places that rely on this cap; update references to _MAX_BT in module-level code (symbol: _MAX_BT) and in the check_shared_mem logic to reference the shared computed value (e.g., AUTOTUNE_BT_OPTIONS or the BT config list) so there is a single source of truth for BT limits.tests/modules/test_layernorm.py (1)
377-392: Strengthen the residual-path regression with numerical parity checks.Current assertions verify liveness (
grad is not None, non-zero) but won’t catch silent math regressions. Consider checkingtrivs reference (dx/dw) for this path too.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/modules/test_layernorm.py` around lines 377 - 392, Enhance test_rmsnorm_varying_nb_with_residual to assert numerical parity instead of only liveness: create a reference computation (e.g., run the same forward/backward for RMSNorm on a stable reference device/precision or an explicit reference implementation) using the same inputs x and residual and collect reference gradients (dx_ref, dw_ref), then compare tri.weight.grad and x.grad against the reference (using torch.allclose with reasonable rtol/atol) to catch silent math regressions; ensure you zero gradients before each backward, keep the same RNG/input copies, and use the existing symbols test_rmsnorm_varying_nb_with_residual, RMSNorm, tri, x, residual for locating and wiring the checks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/modules/layernorm.py`:
- Around line 652-661: The current NS cap uses triton.cdiv(T_per_group, _MAX_BT)
which computes a ceiling and can produce too few programs per group so BS
becomes < _MAX_BT; change the cap to use floor division so NS does not exceed
floor(T_per_group/_MAX_BT)*G (with a minimum of G) — i.e., compute max_ns =
max(T_per_group // _MAX_BT, 1) * G (or equivalent integer floor), then NS =
min(NS, max_ns) so that BS = triton.cdiv(T_per_group, NS//G) will always be >=
_MAX_BT; update the code around symbols _MAX_BT, T_per_group, NS, max_ns, BS
accordingly.
---
Nitpick comments:
In `@fla/modules/layernorm.py`:
- Line 656: The hard-coded _MAX_BT constant should be derived from the autotune
BT options instead of being duplicated; modify the code so _MAX_BT is computed
(e.g., max_bt = max(BT_CONFIG_LIST) or len/ max of the BT values used by the
autotuner) and used by check_shared_mem and any other places that rely on this
cap; update references to _MAX_BT in module-level code (symbol: _MAX_BT) and in
the check_shared_mem logic to reference the shared computed value (e.g.,
AUTOTUNE_BT_OPTIONS or the BT config list) so there is a single source of truth
for BT limits.
In `@tests/modules/test_layernorm.py`:
- Around line 377-392: Enhance test_rmsnorm_varying_nb_with_residual to assert
numerical parity instead of only liveness: create a reference computation (e.g.,
run the same forward/backward for RMSNorm on a stable reference device/precision
or an explicit reference implementation) using the same inputs x and residual
and collect reference gradients (dx_ref, dw_ref), then compare tri.weight.grad
and x.grad against the reference (using torch.allclose with reasonable
rtol/atol) to catch silent math regressions; ensure you zero gradients before
each backward, keep the same RNG/input copies, and use the existing symbols
test_rmsnorm_varying_nb_with_residual, RMSNorm, tri, x, residual for locating
and wiring the checks.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 83d5c329-c760-4653-9a19-03ffe2da7a6f
📒 Files selected for processing (2)
fla/modules/layernorm.pytests/modules/test_layernorm.py
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/modules/layernorm.py`:
- Around line 652-661: The current grouping math can produce BS values not
divisible by the largest backward tile (BT), causing inter-program write
overlaps; update the block-size computation so BS is aligned to a multiple of
the largest BT (_MAX_BT) (or alternatively mask stores by program range).
Specifically, in the code that computes T_per_group, NS, max_ns and BS
(variables T_per_group, NS, max_ns, BS, _MAX_BT), round BS up to the next
multiple of _MAX_BT (e.g., BS = triton.cdiv(BS, _MAX_BT) * _MAX_BT) while still
respecting the group limits (cap BS to T_per_group if needed), or instead modify
the store sites that use boundary_check to clip each program’s store range to
min((i_sg + 1) * BS, Tg) so stores at the locations corresponding to the writes
currently protected only by boundary_check cannot overlap between programs.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: d04994a7-ca74-4c27-a65a-0688ecc39ca4
📒 Files selected for processing (1)
fla/modules/layernorm.py
| # Ensure BS >= max autotuned BT (64) so that each program's | ||
| # make_block_ptr write region doesn't overlap with adjacent | ||
| # programs. On high-SM GPUs with small T, BS < BT causes | ||
| # overlapping writes on dx that corrupt GPU memory. | ||
| _MAX_BT = 64 # largest BT in autotuner configs | ||
| T_per_group = T // G | ||
| NS = triton.cdiv(get_multiprocessor_count(x.device.index), G) * G | ||
| max_ns = max(T_per_group // _MAX_BT, 1) * G | ||
| NS = min(NS, max_ns) | ||
| BS = triton.cdiv(T_per_group, NS // G) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "Affected store sites in fla/modules/layernorm.py:"
nl -ba fla/modules/layernorm.py | sed -n '403,433p'
echo
echo "Counterexamples for the current BS/GS cap:"
python - <<'PY'
def touched_rows(program, bs, bt, tg):
rows = set()
for start in range(program * bs, program * bs + bs, bt):
rows.update(range(start, min(start + bt, tg)))
return rows
MAX_BT = 64
for tg in (129, 130, 191, 257):
gs = max(tg // MAX_BT, 1) # mirrors the current cap when SM count is not the limiter
bs = (tg + gs - 1) // gs
for bt in (64, 32):
overlap = sorted(touched_rows(0, bs, bt, tg) & touched_rows(1, bs, bt, tg))
if overlap:
print(f"Tg={tg} GS={gs} BS={bs} BT={bt} overlap={overlap[0]}..{overlap[-1]}")
PYRepository: fla-org/flash-linear-attention
Length of output: 183
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Affected store sites in fla/modules/layernorm.py (lines 403-433) ==="
cat -n fla/modules/layernorm.py | sed -n '403,433p'
echo ""
echo "=== Code at lines 652-661 (the BS/GS cap) ==="
cat -n fla/modules/layernorm.py | sed -n '652,661p'
echo ""
echo "=== Counterexamples for the current BS/GS cap ==="
python3 <<'PY'
def touched_rows(program, bs, bt, tg):
"""Compute set of rows touched by a program using block-tile stores."""
rows = set()
for start in range(program * bs, program * bs + bs, bt):
rows.update(range(start, min(start + bt, tg)))
return rows
MAX_BT = 64
print("Testing T_per_group values and BT variants:")
for tg in (129, 130, 191, 257):
gs = max(tg // MAX_BT, 1) # mirrors the cap when SM count is not the limiter
bs = (tg + gs - 1) // gs
for bt in (64, 32):
p0_rows = touched_rows(0, bs, bt, tg)
p1_rows = touched_rows(1, bs, bt, tg)
overlap = sorted(p0_rows & p1_rows)
if overlap:
print(f"Tg={tg:3d} GS={gs} BS={bs:3d} BT={bt:2d}: overlap={overlap[0]:3d}..{overlap[-1]:3d}")
else:
print(f"Tg={tg:3d} GS={gs} BS={bs:3d} BT={bt:2d}: NO overlap")
PYRepository: fla-org/flash-linear-attention
Length of output: 2980
BS >= 64 still leaves overlapping backward tiles.
Lines 659–661 cap the number of programs, but they do not ensure each program's span is a multiple of BT. With T_per_group=129 and G=1, this computes GS=2 and BS=65; if the autotuner picks BT=64, program 0 writes rows 0..63 and program 1 writes rows 65..128, so rows 65..127 are written by both programs. The m_t guard on line 411 only protects dw/db accumulation (lines 414, 416); the stores on lines 404, 431, and 433 use only boundary_check=(0, 1), which checks tensor bounds but does not prevent inter-program overlaps. The overlap corrupts GPU memory. Either align BS to a multiple of the largest backward BT, or mask the stores on lines 404, 431, 433 to each program's range: min((i_sg + 1) * BS, Tg).
Multiple concrete counterexamples trigger this: (Tg=129, GS=2, BS=65, BT=64) → rows 65..127 overlap; (Tg=191, GS=2, BS=96, BT=64) → rows 96..127 overlap; (Tg=257, GS=4, BS=65, BT=64) → rows 65..127 overlap.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/modules/layernorm.py` around lines 652 - 661, The current grouping math
can produce BS values not divisible by the largest backward tile (BT), causing
inter-program write overlaps; update the block-size computation so BS is aligned
to a multiple of the largest BT (_MAX_BT) (or alternatively mask stores by
program range). Specifically, in the code that computes T_per_group, NS, max_ns
and BS (variables T_per_group, NS, max_ns, BS, _MAX_BT), round BS up to the next
multiple of _MAX_BT (e.g., BS = triton.cdiv(BS, _MAX_BT) * _MAX_BT) while still
respecting the group limits (cap BS to T_per_group if needed), or instead modify
the store sites that use boundary_check to clip each program’s store range to
min((i_sg + 1) * BS, Tg) so stores at the locations corresponding to the writes
currently protected only by boundary_check cannot overlap between programs.
Summary
Fixes two bugs in
layer_norm_bwd_kernelthat cause "CUDA error: illegal memoryaccess" on high-SM GPUs (Blackwell B200/B300 with 188 SMs). Both bugs are
correctness fixes safe for all hardware.
Relationship to #795
PR #795 fixed a different bug in the same kernel: when
NS > T // G, idleprograms (with no tokens to process) accessed invalid memory through
make_block_ptrbecause the kernel reassignedTand used the capped valueas both the tensor shape and loop bound. The fix separated
T_g(tensor shape)from
T_end(loop bound) so idle programs get empty loop ranges.This PR addresses two additional bugs that #795 did not cover:
Autotuner crash (new bug): Even with correct kernel logic, the Triton
autotuner itself crashes on sm_120 when benchmarking certain kernel variants
(
HAS_DRESIDUAL=False) atNS=188grid size. This is triggered byNBinthe autotuner key forcing re-autotuning for each new T range. Fix layer_norm_bwd_kernel OOB access on high-SM GPUs #795's kernel
fix is correct but the autotuner never gets to run the fixed kernel — it
crashes during the benchmark phase before selecting a config.
Overlapping writes (existing bug): When
BS < BT(small T, many SMs),adjacent programs'
make_block_ptrblocks overlap on the output tensordx.Fix layer_norm_bwd_kernel OOB access on high-SM GPUs #795's
T_g/T_endfix ensures correct shapes and loop bounds, but doesnot prevent the overlapping write regions. This PR caps
NSsoBS >= BT,eliminating the overlap entirely.
Both bugs manifest on high-SM GPUs (188 SMs) and were not caught by #795's
regression tests because those tests used small T values (1–32) where
NSisalso small (capped by our fix). The new
test_rmsnorm_varying_nb_*testsexercise large T values (up to 24000) with full
NS=188to catch both issues.Bug 1: Autotuner crash from phantom NB key
NB = cdiv(T, 2048)was included in the autotuner key but never referenced inthe kernel body. Each new NB value forced re-autotuning, and on Blackwell sm_120,
the Triton autotuner crashes when benchmarking the
HAS_DRESIDUAL=Falsevariantat
NS=188grid size.Fix: Remove
NBfrom autotuner keys (fwd + bwd). The autotuner runs once per(D, HAS_DRESIDUAL, STORE_DRESIDUAL, IS_RMS_NORM)and reuses the config for allT values.
Bug 2: Overlapping writes when BS < BT
On high-SM GPUs with small T (e.g., 450 tokens, 188 SMs),
BS = cdiv(450, 188) = 3.With
BT=64, each program'smake_block_ptrblock covers 64 rows but only owns 3,causing overlapping writes on
dxacross adjacent programs.Fix: Cap
NSsoBS >= max(BT) = 64. For T=450: NS reduced from 188 to 8,BS increased from 3 to 57. Uses floor division (
//) instead ofcdivto guaranteeBS >= BTfor all T values.Reproducer
Environment
Compatibility
All fixes are safe for all hardware:
Test plan
Summary by CodeRabbit
Bug Fixes
Tests