Skip to content

[Layernorm] Fix autotuner crash and OOB writes in layer_norm_bwd on high-SM GPUs#796

Open
mpurland wants to merge 2 commits into
fla-org:mainfrom
mpurland:mpurland/fix/fla-layernorm-bwd-autotuner-crash-high-sm
Open

[Layernorm] Fix autotuner crash and OOB writes in layer_norm_bwd on high-SM GPUs#796
mpurland wants to merge 2 commits into
fla-org:mainfrom
mpurland:mpurland/fix/fla-layernorm-bwd-autotuner-crash-high-sm

Conversation

@mpurland
Copy link
Copy Markdown
Contributor

@mpurland mpurland commented Mar 28, 2026

Summary

Fixes two bugs in layer_norm_bwd_kernel that cause "CUDA error: illegal memory
access" on high-SM GPUs (Blackwell B200/B300 with 188 SMs). Both bugs are
correctness fixes safe for all hardware.

Relationship to #795

PR #795 fixed a different bug in the same kernel: when NS > T // G, idle
programs (with no tokens to process) accessed invalid memory through
make_block_ptr because the kernel reassigned T and used the capped value
as both the tensor shape and loop bound. The fix separated T_g (tensor shape)
from T_end (loop bound) so idle programs get empty loop ranges.

This PR addresses two additional bugs that #795 did not cover:

  1. Autotuner crash (new bug): Even with correct kernel logic, the Triton
    autotuner itself crashes on sm_120 when benchmarking certain kernel variants
    (HAS_DRESIDUAL=False) at NS=188 grid size. This is triggered by NB in
    the autotuner key forcing re-autotuning for each new T range. Fix layer_norm_bwd_kernel OOB access on high-SM GPUs #795's kernel
    fix is correct but the autotuner never gets to run the fixed kernel — it
    crashes during the benchmark phase before selecting a config.

  2. Overlapping writes (existing bug): When BS < BT (small T, many SMs),
    adjacent programs' make_block_ptr blocks overlap on the output tensor dx.
    Fix layer_norm_bwd_kernel OOB access on high-SM GPUs #795's T_g/T_end fix ensures correct shapes and loop bounds, but does
    not prevent the overlapping write regions. This PR caps NS so BS >= BT,
    eliminating the overlap entirely.

Both bugs manifest on high-SM GPUs (188 SMs) and were not caught by #795's
regression tests because those tests used small T values (1–32) where NS is
also small (capped by our fix). The new test_rmsnorm_varying_nb_* tests
exercise large T values (up to 24000) with full NS=188 to catch both issues.

Bug 1: Autotuner crash from phantom NB key

NB = cdiv(T, 2048) was included in the autotuner key but never referenced in
the kernel body. Each new NB value forced re-autotuning, and on Blackwell sm_120,
the Triton autotuner crashes when benchmarking the HAS_DRESIDUAL=False variant
at NS=188 grid size.

Fix: Remove NB from autotuner keys (fwd + bwd). The autotuner runs once per
(D, HAS_DRESIDUAL, STORE_DRESIDUAL, IS_RMS_NORM) and reuses the config for all
T values.

Bug 2: Overlapping writes when BS < BT

On high-SM GPUs with small T (e.g., 450 tokens, 188 SMs), BS = cdiv(450, 188) = 3.
With BT=64, each program's make_block_ptr block covers 64 rows but only owns 3,
causing overlapping writes on dx across adjacent programs.

Fix: Cap NS so BS >= max(BT) = 64. For T=450: NS reduced from 188 to 8,
BS increased from 3 to 57. Uses floor division (//) instead of cdiv to guarantee BS >= BT for all T values.

Reproducer

import torch
from fla.modules import RMSNorm
                               
norm = RMSNorm(256).cuda()                                                                                                                                                                  
# HAS_DRESIDUAL=True works:
x1 = torch.randn(24000, 256, device="cuda", requires_grad=True)                                                                                                                             
norm(x1, residual=torch.randn_like(x1), prenorm=True)[0].sum().backward()  # OK                                                                                                             
# HAS_DRESIDUAL=False crashes:                                                                                                                                                              
x2 = torch.randn(24000, 256, device="cuda", requires_grad=True)                                                                                                                             
norm(x2).sum().backward()  # CUDA illegal memory access                                                                                                                                     

Environment

  • GPU: NVIDIA B200 (188 SMs, sm_120)
  • CUDA: 13.0
  • PyTorch: 2.10.0–2.11.0
  • Triton: 3.6.0
  • Python: 3.14.0 (free-threaded)

Compatibility

All fixes are safe for all hardware:

  • NB removal: Fewer redundant autotuner runs; optimal config is T-independent
  • NS cap: Only activates when T < SM_count × 64; typical batch sizes unaffected

Test plan

  • All existing test_layernorm.py tests pass (no regression)
  • test_rmsnorm_varying_nb_no_residual — T ∈ {100, 500, 5K, 10K, 20K, 24K}, D=256
  • test_rmsnorm_varying_nb_with_residual — same T sweep
  • test_rmsnorm_small_t / test_layernorm_small_t / test_groupnorm_small_t — T ∈ {1..32}
  • Verified: regression tests FAIL without fix, PASS with fix

Summary by CodeRabbit

  • Bug Fixes

    • Fixed autotuner failure in layer normalization when processing large sequence lengths.
    • Improved backward kernel scheduling for stable performance across varying input dimensions.
  • Tests

    • Added regression tests for RMSNorm backward pass with large sequence lengths, with and without residual connections.

  Two fixes for layer_norm_bwd_kernel on Blackwell sm_120 (188 SMs):

  1. Remove NB from autotuner keys (fwd + bwd kernels)

     NB = cdiv(T, 2048) was included in the autotuner key to force
     re-autotuning when the token count changed significantly. However,
     NB is never referenced in the kernel body — it exists solely as a
     cache key. On Blackwell sm_120 with 188 SMs, the Triton autotuner
     crashes with "CUDA error: illegal memory access" when benchmarking
     the HAS_DRESIDUAL=False kernel variant at NS=188 grid size. Each
     new NB value triggered a fresh autotuner benchmark run, repeatedly
     hitting this Triton codegen/runtime bug.

     Removing NB from the key means the autotuner runs once per
     (D, HAS_DRESIDUAL, STORE_DRESIDUAL, IS_RMS_NORM) combination and
     reuses the cached config for all T values. The first autotuning
     at a small T (safe grid size) produces a config that works for
     all subsequent larger T values.

     Reproduced with minimal test: RMSNorm(256) backward on T=24000
     without residual crashes during autotuner benchmarking. Same
     shape with residual (HAS_DRESIDUAL=True) succeeds.

  2. Cap NS so BS >= max autotuned BT (64)

     On high-SM GPUs with small T (e.g., 450 tokens on 188 SMs),
     BS = cdiv(450, 188) = 3. The autotuner tries BT=64, so BS < BT.
     Each program's make_block_ptr creates a block of (BT=64, BD)
     rows, but the program only owns BS=3 rows. Adjacent programs'
     blocks overlap on the output tensor dx, causing write races that
     corrupt GPU memory. Capping NS so each program handles at least
     _MAX_BT=64 tokens ensures non-overlapping write regions.

     For T=450: NS reduced from 188 to 8, BS increased from 3 to 57.

  Both fixes include the T_g/T_end fix from the prior commit for
  correct make_block_ptr shapes and loop bounds when NS > T // G
  (idle programs with no work get empty loop ranges).

  Add regression tests:
    - test_rmsnorm_varying_nb_no_residual: T in {100..24000}, D=256
      Exercises HAS_DRESIDUAL=False across different NB values
    - test_rmsnorm_varying_nb_with_residual: same T range
      Exercises HAS_DRESIDUAL=True for completeness
    - Verified: tests FAIL without fix (crash at T=24000), PASS with
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Mar 28, 2026

Walkthrough

Removed NB from Triton autotuner cache keys for layer-norm kernels and changed backward launch-tiling logic to cap per-group programs relative to a max forward block tile (_MAX_BT = 64); added regression tests exercising large sequence lengths.

Changes

Cohort / File(s) Summary
Kernel autotuner & launch logic
fla/modules/layernorm.py
Removed NB from autotuner cache key lists for forward and backward Triton kernels. Reworked layer_norm_bwd host-side launch sizing: compute tokens per group, clamp number of programs using _MAX_BT = 64, and derive BS from per-group sizing instead of previous NS/BS scheme.
Regression tests
tests/modules/test_layernorm.py
Added parameterized RMSNorm backward regression tests across large T values (D=256) for both no-residual and residual-with-prenorm cases to detect autotuner/tile-related failures.

Sequence Diagram(s)

(omitted — changes are kernel and sizing logic without multi-component sequential interactions warranting a diagram)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • sustcsonglin
  • yzhangcs

Poem

🐰 A tiny key dropped, NB hops away,
Programs regroup where tokens now stay,
Tiles lined by sixty-four's gentle hand,
Tests bound the borrows across sequence land,
A rabbit cheers — kernels safe and gay! ✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Title check ✅ Passed The title clearly and specifically identifies the main fixes: autotuner crash and OOB (out-of-bounds) writes in layer_norm_bwd on high-SM GPUs, directly matching the PR's primary objectives.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses autotuner crashes on high-SM GPUs by removing 'NB' from the autotune key and ensures that the block size in the backward pass is at least the maximum autotuned block size to prevent memory corruption. A logic error was identified in the calculation of the maximum number of programs per group, where the use of 'triton.cdiv' instead of integer division could still result in overlapping writes.

Comment thread fla/modules/layernorm.py Outdated
  triton.cdiv rounds up, which can result in BS < _MAX_BT. For example,
  T_per_group=65: cdiv(65, 64) = 2, BS = cdiv(65, 2) = 33 < 64.

  Floor division guarantees each program handles at least _MAX_BT tokens:
  T_per_group=65: 65 // 64 = 1, BS = cdiv(65, 1) = 65 >= 64.
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
fla/modules/layernorm.py (1)

656-656: Avoid hard-coding _MAX_BT separately from autotune configs.

Please derive this cap from a shared constant (or from the BT config list) to prevent drift if autotune BT options change later.

Based on learnings: Align threshold constants used by check_shared_mem to a single source of truth to avoid semantic drift.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/modules/layernorm.py` at line 656, The hard-coded _MAX_BT constant should
be derived from the autotune BT options instead of being duplicated; modify the
code so _MAX_BT is computed (e.g., max_bt = max(BT_CONFIG_LIST) or len/ max of
the BT values used by the autotuner) and used by check_shared_mem and any other
places that rely on this cap; update references to _MAX_BT in module-level code
(symbol: _MAX_BT) and in the check_shared_mem logic to reference the shared
computed value (e.g., AUTOTUNE_BT_OPTIONS or the BT config list) so there is a
single source of truth for BT limits.
tests/modules/test_layernorm.py (1)

377-392: Strengthen the residual-path regression with numerical parity checks.

Current assertions verify liveness (grad is not None, non-zero) but won’t catch silent math regressions. Consider checking tri vs reference (dx/dw) for this path too.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/modules/test_layernorm.py` around lines 377 - 392, Enhance
test_rmsnorm_varying_nb_with_residual to assert numerical parity instead of only
liveness: create a reference computation (e.g., run the same forward/backward
for RMSNorm on a stable reference device/precision or an explicit reference
implementation) using the same inputs x and residual and collect reference
gradients (dx_ref, dw_ref), then compare tri.weight.grad and x.grad against the
reference (using torch.allclose with reasonable rtol/atol) to catch silent math
regressions; ensure you zero gradients before each backward, keep the same
RNG/input copies, and use the existing symbols
test_rmsnorm_varying_nb_with_residual, RMSNorm, tri, x, residual for locating
and wiring the checks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/modules/layernorm.py`:
- Around line 652-661: The current NS cap uses triton.cdiv(T_per_group, _MAX_BT)
which computes a ceiling and can produce too few programs per group so BS
becomes < _MAX_BT; change the cap to use floor division so NS does not exceed
floor(T_per_group/_MAX_BT)*G (with a minimum of G) — i.e., compute max_ns =
max(T_per_group // _MAX_BT, 1) * G (or equivalent integer floor), then NS =
min(NS, max_ns) so that BS = triton.cdiv(T_per_group, NS//G) will always be >=
_MAX_BT; update the code around symbols _MAX_BT, T_per_group, NS, max_ns, BS
accordingly.

---

Nitpick comments:
In `@fla/modules/layernorm.py`:
- Line 656: The hard-coded _MAX_BT constant should be derived from the autotune
BT options instead of being duplicated; modify the code so _MAX_BT is computed
(e.g., max_bt = max(BT_CONFIG_LIST) or len/ max of the BT values used by the
autotuner) and used by check_shared_mem and any other places that rely on this
cap; update references to _MAX_BT in module-level code (symbol: _MAX_BT) and in
the check_shared_mem logic to reference the shared computed value (e.g.,
AUTOTUNE_BT_OPTIONS or the BT config list) so there is a single source of truth
for BT limits.

In `@tests/modules/test_layernorm.py`:
- Around line 377-392: Enhance test_rmsnorm_varying_nb_with_residual to assert
numerical parity instead of only liveness: create a reference computation (e.g.,
run the same forward/backward for RMSNorm on a stable reference device/precision
or an explicit reference implementation) using the same inputs x and residual
and collect reference gradients (dx_ref, dw_ref), then compare tri.weight.grad
and x.grad against the reference (using torch.allclose with reasonable
rtol/atol) to catch silent math regressions; ensure you zero gradients before
each backward, keep the same RNG/input copies, and use the existing symbols
test_rmsnorm_varying_nb_with_residual, RMSNorm, tri, x, residual for locating
and wiring the checks.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 83d5c329-c760-4653-9a19-03ffe2da7a6f

📥 Commits

Reviewing files that changed from the base of the PR and between f52529e and b20ea76.

📒 Files selected for processing (2)
  • fla/modules/layernorm.py
  • tests/modules/test_layernorm.py

Comment thread fla/modules/layernorm.py
Comment thread fla/modules/layernorm.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/modules/layernorm.py`:
- Around line 652-661: The current grouping math can produce BS values not
divisible by the largest backward tile (BT), causing inter-program write
overlaps; update the block-size computation so BS is aligned to a multiple of
the largest BT (_MAX_BT) (or alternatively mask stores by program range).
Specifically, in the code that computes T_per_group, NS, max_ns and BS
(variables T_per_group, NS, max_ns, BS, _MAX_BT), round BS up to the next
multiple of _MAX_BT (e.g., BS = triton.cdiv(BS, _MAX_BT) * _MAX_BT) while still
respecting the group limits (cap BS to T_per_group if needed), or instead modify
the store sites that use boundary_check to clip each program’s store range to
min((i_sg + 1) * BS, Tg) so stores at the locations corresponding to the writes
currently protected only by boundary_check cannot overlap between programs.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: d04994a7-ca74-4c27-a65a-0688ecc39ca4

📥 Commits

Reviewing files that changed from the base of the PR and between b20ea76 and 958bd16.

📒 Files selected for processing (1)
  • fla/modules/layernorm.py

Comment thread fla/modules/layernorm.py
Comment on lines +652 to +661
# Ensure BS >= max autotuned BT (64) so that each program's
# make_block_ptr write region doesn't overlap with adjacent
# programs. On high-SM GPUs with small T, BS < BT causes
# overlapping writes on dx that corrupt GPU memory.
_MAX_BT = 64 # largest BT in autotuner configs
T_per_group = T // G
NS = triton.cdiv(get_multiprocessor_count(x.device.index), G) * G
max_ns = max(T_per_group // _MAX_BT, 1) * G
NS = min(NS, max_ns)
BS = triton.cdiv(T_per_group, NS // G)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Affected store sites in fla/modules/layernorm.py:"
nl -ba fla/modules/layernorm.py | sed -n '403,433p'

echo
echo "Counterexamples for the current BS/GS cap:"
python - <<'PY'
def touched_rows(program, bs, bt, tg):
    rows = set()
    for start in range(program * bs, program * bs + bs, bt):
        rows.update(range(start, min(start + bt, tg)))
    return rows

MAX_BT = 64

for tg in (129, 130, 191, 257):
    gs = max(tg // MAX_BT, 1)  # mirrors the current cap when SM count is not the limiter
    bs = (tg + gs - 1) // gs
    for bt in (64, 32):
        overlap = sorted(touched_rows(0, bs, bt, tg) & touched_rows(1, bs, bt, tg))
        if overlap:
            print(f"Tg={tg} GS={gs} BS={bs} BT={bt} overlap={overlap[0]}..{overlap[-1]}")
PY

Repository: fla-org/flash-linear-attention

Length of output: 183


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Affected store sites in fla/modules/layernorm.py (lines 403-433) ==="
cat -n fla/modules/layernorm.py | sed -n '403,433p'

echo ""
echo "=== Code at lines 652-661 (the BS/GS cap) ==="
cat -n fla/modules/layernorm.py | sed -n '652,661p'

echo ""
echo "=== Counterexamples for the current BS/GS cap ==="
python3 <<'PY'
def touched_rows(program, bs, bt, tg):
    """Compute set of rows touched by a program using block-tile stores."""
    rows = set()
    for start in range(program * bs, program * bs + bs, bt):
        rows.update(range(start, min(start + bt, tg)))
    return rows

MAX_BT = 64

print("Testing T_per_group values and BT variants:")
for tg in (129, 130, 191, 257):
    gs = max(tg // MAX_BT, 1)  # mirrors the cap when SM count is not the limiter
    bs = (tg + gs - 1) // gs
    for bt in (64, 32):
        p0_rows = touched_rows(0, bs, bt, tg)
        p1_rows = touched_rows(1, bs, bt, tg)
        overlap = sorted(p0_rows & p1_rows)
        if overlap:
            print(f"Tg={tg:3d} GS={gs} BS={bs:3d} BT={bt:2d}: overlap={overlap[0]:3d}..{overlap[-1]:3d}")
        else:
            print(f"Tg={tg:3d} GS={gs} BS={bs:3d} BT={bt:2d}: NO overlap")
PY

Repository: fla-org/flash-linear-attention

Length of output: 2980


BS >= 64 still leaves overlapping backward tiles.

Lines 659–661 cap the number of programs, but they do not ensure each program's span is a multiple of BT. With T_per_group=129 and G=1, this computes GS=2 and BS=65; if the autotuner picks BT=64, program 0 writes rows 0..63 and program 1 writes rows 65..128, so rows 65..127 are written by both programs. The m_t guard on line 411 only protects dw/db accumulation (lines 414, 416); the stores on lines 404, 431, and 433 use only boundary_check=(0, 1), which checks tensor bounds but does not prevent inter-program overlaps. The overlap corrupts GPU memory. Either align BS to a multiple of the largest backward BT, or mask the stores on lines 404, 431, 433 to each program's range: min((i_sg + 1) * BS, Tg).

Multiple concrete counterexamples trigger this: (Tg=129, GS=2, BS=65, BT=64) → rows 65..127 overlap; (Tg=191, GS=2, BS=96, BT=64) → rows 96..127 overlap; (Tg=257, GS=4, BS=65, BT=64) → rows 65..127 overlap.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/modules/layernorm.py` around lines 652 - 661, The current grouping math
can produce BS values not divisible by the largest backward tile (BT), causing
inter-program write overlaps; update the block-size computation so BS is aligned
to a multiple of the largest BT (_MAX_BT) (or alternatively mask stores by
program range). Specifically, in the code that computes T_per_group, NS, max_ns
and BS (variables T_per_group, NS, max_ns, BS, _MAX_BT), round BS up to the next
multiple of _MAX_BT (e.g., BS = triton.cdiv(BS, _MAX_BT) * _MAX_BT) while still
respecting the group limits (cap BS to T_per_group if needed), or instead modify
the store sites that use boundary_check to clip each program’s store range to
min((i_sg + 1) * BS, Tg) so stores at the locations corresponding to the writes
currently protected only by boundary_check cannot overlap between programs.

@zhiyuan1i zhiyuan1i changed the title Fix autotuner crash and OOB writes in layer_norm_bwd on high-SM GPUs [Layernorm] Fix autotuner crash and OOB writes in layer_norm_bwd on high-SM GPUs Mar 29, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant