
[2/n] Add sparse softmax to the Triton flash attention kernel#1078

Merged
kaix-nv merged 6 commits into main from kaix/triton_fa_sparse24
Mar 26, 2026
Conversation

@kaix-nv
Copy link
Contributor

@kaix-nv kaix-nv commented Mar 19, 2026

What does this PR do?


Type of change: New feature

Add N:M structured sparsity support to the Triton flash attention kernel (modelopt/torch/kernels/triton_fa.py). For every M consecutive key positions in the attention score tile, the kernel keeps the top-N values and sets the rest to -inf before softmax. Sparsity is applied during prefill only.

Supported patterns: Any N:M where M=4 (N=1,2,3) or M=8 (N=1..4).

  • Sink tokens and dense window blocks stay dense, preserving attention sinks and local attention
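As a reference for the masking rule described above, here is a minimal pure-Python sketch of N:M top-N selection over one score row (the actual kernel does this per tile inside Triton; the function name is illustrative):

```python
import math

def nm_sparse_scores(scores, n, m):
    """Reference N:M masking: for every M consecutive key scores,
    keep the top-N values and set the rest to -inf (applied before softmax)."""
    out = list(scores)
    for i in range(0, len(out), m):
        group = out[i:i + m]
        thresh = sorted(group, reverse=True)[n - 1]  # N-th largest within the group
        for j, s in enumerate(group):
            if s < thresh:
                out[i + j] = -math.inf  # pruned scores vanish after softmax
    return out

row = [0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7, 0.4]
print(nm_sparse_scores(row, 2, 4))
# [0.9, -inf, 0.5, -inf, 0.8, -inf, 0.7, -inf]
```

Masked entries contribute exp(-inf) = 0 to the softmax denominator, so each query row attends to at most N keys per group of M.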

Performance (TFLOPS at seq_len=16384, RTX 6000):

Pattern      TFLOPS   % of Dense
Dense        89.3     100%
2:4 (M=4)    69.5     78%
4:8 (M=8)    57.3     64%

Usage

from modelopt.torch.kernels import attention

# 2:4 sparsity (keep top 2 of every 4 K positions)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=2, sparsity_m=4)

# 4:8 sparsity with sink tokens and dense window
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=4, sparsity_m=8,
                num_sink_tokens=4, dense_window_blocks=2)

# Dense (default, zero overhead)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len)

# Via mtsa.sparsify() on HuggingFace models
import torch
import modelopt.torch.sparsity.attention_sparsity as mtsa
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B",
                                              torch_dtype=torch.bfloat16,
                                              device_map="cuda")

# Default config
mtsa.sparsify(model, mtsa.SPARSE_SOFTMAX_DEFAULT)

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • N:M structured sparse softmax for Triton flash-attention prefill with configurable dense-window and sink-token handling.
  • API

    • attention(...) accepts sparsity_n, sparsity_m, num_sink_tokens, dense_window_size; HF/Triton prefill path propagates them.
  • Configuration

    • New config fields and exported preset to enable/configure Triton N:M sparse softmax with validation.
  • Tests

    • Added GPU tests covering N:M behavior, tile structure, forward/backward correctness.
  • Documentation

    • CHANGELOG and example docs updated with usage and CLI options.

@copy-pr-bot
Copy link

copy-pr-bot bot commented Mar 19, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 19, 2026

Caution

Review failed

Pull request was closed or merged during review

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds Triton-backed N:M sparse-softmax to flash attention: kernel tile helpers and constexpr params, autograd and public API plumbing, sparsity method/config registration and HF prefill gating, plus GPU tests and a changelog entry.

Changes

Cohort / File(s) Summary
Changelog
CHANGELOG.rst
Added 0.44 entry documenting N:M sparse softmax support for the Triton flash-attention kernel.
Triton flash-attention core
modelopt/torch/kernels/triton_fa.py
Added Triton JIT helpers for N:M masking, dense-region gating, new constexpr params (SPARSITY_N, SPARSITY_M, NUM_SINK_TOKENS, DENSE_WINDOW_SIZE), conditional masking in forward and mirrored gating in backward recomputation, and propagated sparsity/window args through autograd _Attention and public attention(...).
HF wrapper integration
modelopt/torch/kernels/hf_triton_attention.py
Prefill path now injects Triton kernel kwargs (sparsity_n, sparsity_m, num_sink_tokens, dense_window_size) when a sparse-method instance toggles _apply_sparse_nm; decode unchanged.
Sparsity config & default
modelopt/torch/sparsity/attention_sparsity/config.py
Added config fields (sparsity_n, sparsity_m, num_sink_tokens, dense_window_size), per-field validators (e.g., sparsity_m ∈ {4,8}, non-negative checks), cross-field cap on sparsity_n, and exported SPARSE_SOFTMAX_DEFAULT.
Methods package init
modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
Now imports triton_sparse_softmax at package init so the Triton-backed method registers on import.
Triton sparse-softmax method
modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
Added TritonSparseSoftmaxMethod (name triton_sparse_softmax) storing sparsity params and providing get_sparse_context() which toggles module._apply_sparse_nm during the context.
Method base / registry
modelopt/torch/sparsity/attention_sparsity/methods/registry.py
Made calculate_sparsity() return an all-True boolean mask by default and implemented apply_sparsity() to raise NotImplementedError (kernel fusion expected); removed abstract requirement for these two methods.
Test utilities
tests/gpu/torch/sparsity/attention_sparsity/conftest.py
New GPU test fixtures/helpers: make_qkv, make_varlen_meta, sdpa_reference, and tiny_llama_dir to centralize test data creation and SDPA reference computation.
Tests (refactor + dense checks)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
Refactored to use shared conftest helpers, reorganized forward/backward tests, added test_sparse_disabled_matches_dense to assert bit-identical dense behavior when sparsity_n=0, and updated HF integration tests to exercise sparsify config.
New N:M sparsity tests
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
New Triton-gated GPU tests for prefill N:M behavior: end-to-end checks (TestSparseNM), tile-level unit tests for _apply_sparse_nm_to_qk_tile (TestSparseTileStructure), and backward gradient sanity tests (TestSparseBackward).
Examples / docs
examples/llm_sparsity/attention_sparsity/README.md, examples/llm_sparsity/attention_sparsity/hf_sa.py
Documented N:M sparse softmax, added SPARSE_SOFTMAX_DEFAULT usage example, updated CLI choices to include sparse_softmax, adjusted backend override behavior, and noted N:M is applied only during prefill (Triton backend).

Sequence Diagram(s)

sequenceDiagram
    participant Caller as "Caller / Model"
    participant API as "attention(...) / _Attention"
    participant Autograd as "Autograd Function"
    participant Triton as "Triton FA Kernel"
    participant KV as "KV Storage / Tiles"

    rect rgba(100,149,237,0.5)
    Caller->>API: call attention(q,k,v,..., sparsity_n, sparsity_m, num_sink_tokens, dense_window_size)
    API->>Autograd: _Attention.forward(..., sparsity params)
    Autograd->>Triton: launch forward kernel (constexpr sparsity params)
    Triton->>KV: load QK tile
    Triton->>Triton: apply N:M mask (tile-level) -> set pruned scores to -inf
    Triton->>Triton: softmax & compute output context
    Triton-->>Autograd: return outputs + saved tensors (incl. sparsity params)
    end

    rect rgba(60,179,113,0.5)
    Autograd->>Triton: backward launch with same sparsity params
    Triton->>Triton: recompute masked scores (respect sink/window), compute dq/dk/dv
    Triton-->>Autograd: return gradients
    Autograd-->>Caller: propagate gradients
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check — Passed (check skipped; CodeRabbit's high-level summary is enabled)
  • Title Check — Passed (the PR title directly and specifically describes the main feature addition: N:M sparse softmax support for the Triton flash attention kernel)
  • Docstring Coverage — Passed (docstring coverage is 90.57%, above the required 80.00% threshold)
  • Security Anti-Patterns — Passed (code reviewed against all security anti-patterns in SECURITY.md; no violations found across all six categories)


Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov
Copy link

codecov bot commented Mar 19, 2026

Codecov Report

❌ Patch coverage is 58.92857% with 23 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.18%. Comparing base (291498b) to head (9fdeccd).
⚠️ Report is 1 commit behind head on main.

Files with missing lines                                 Patch %   Lines
...ttention_sparsity/methods/triton_sparse_softmax.py    36.36%    14 Missing ⚠️
...delopt/torch/sparsity/attention_sparsity/config.py    75.00%     8 Missing ⚠️
...ch/sparsity/attention_sparsity/methods/registry.py     0.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1078      +/-   ##
==========================================
- Coverage   70.21%   70.18%   -0.03%     
==========================================
  Files         228      229       +1     
  Lines       25952    26008      +56     
==========================================
+ Hits        18221    18254      +33     
- Misses       7731     7754      +23     

☔ View full report in Codecov by Sentry.

@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch 2 times, most recently from 8ba6efe to 7aa6960 on March 20, 2026 04:47
@kaix-nv kaix-nv marked this pull request as ready for review March 20, 2026 05:16
@kaix-nv kaix-nv requested a review from a team as a code owner March 20, 2026 05:16
@kaix-nv kaix-nv requested review from ChenhanYu, Edwardf0t1, cjluo-nv, kevalmorabia97 and rohansjoshi and removed request for ChenhanYu and Edwardf0t1 March 20, 2026 05:16
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from 7aa6960 to 31655ce on March 21, 2026 19:43
@kaix-nv kaix-nv requested a review from a team as a code owner March 21, 2026 19:43
@kaix-nv kaix-nv changed the title Add 2:4 sparse softmax to the Triton flash attention kernel Add sparse softmax to the Triton flash attention kernel Mar 21, 2026
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 279-292: The N:M sparsity branch is being applied during decode
because it only checks SPARSITY_N > 0; change the condition to skip
sparsification when doing cached decode (seq_len_q == 1). Update the if guarding
the block (the one that currently reads "if SPARSITY_N > 0:") to also require
not decoding (e.g., "if SPARSITY_N > 0 and seq_len_q != 1:") so
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) is
only called during prefill/non-decoding paths; keep the existing local/sink
logic (kv_start, tile_q, q_abs_block, is_local/is_sink) unchanged.
- Around line 279-292: The sparse-mask logic currently mixes tile-sized units
(BLOCK_M/BLOCK_N) with sparsity parameters (NUM_SINK_BLOCKS,
DENSE_WINDOW_BLOCKS) leading to inconsistent masks between forward/backward; fix
by computing mask membership in token-space instead of tile-space: derive each
KV block index and query row absolute token position from actual token counts
(use seq_len_kv, seq_len_q, kv_start and per-row start = tile_q * BLOCK_M +
row_offset or for whole-tile use tile_token_start = tile_q * BLOCK_M) and then
map those token positions into logical token-blocks of a fixed reference block
size (choose the constant used by backward kernels, e.g., 64 tokens) before
comparing to NUM_SINK_BLOCKS and DENSE_WINDOW_BLOCKS; update q_abs_block,
kv_block_idx, is_sink and is_local computations (the branch that calls
_apply_sparse_nm_to_qk_tile) and apply same normalization in the other
occurrences mentioned (around the other two locations) so forward and backward
use the same token-blocking semantics.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-130: Add validation to the config so invalid N:M and negative
counts are rejected when building the config instead of failing in the Triton
kernel: in the class/constructor where ModeloptField(s) sparsity_n, sparsity_m,
num_sink_blocks, and dense_window_blocks are defined, validate that sparsity_m
is either 4 or 8, sparsity_n is in {1,2,3} when sparsity_m==4 and in {1,2,3,4}
when sparsity_m==8 (or 0/disabled as your semantics require), and that
num_sink_blocks and dense_window_blocks are non-negative; also ensure that the
chosen sparsity mode is only allowed when triton_sparse_softmax is selected or
available and raise a clear config validation error if any rule is violated so
bad configs fail early.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`:
- Around line 350-358: Add an explicit assertion after calling
mtsa.sparsify(...) to verify the Triton backend was applied: check that
model_sparse.config._attn_implementation == "modelopt_triton" (this should be
done right after mtsa.sparsify(...) returns and before comparing logits). Locate
the sparsification call (mtsa.sparsify(..., backend="triton", ...)) and add the
assertion referencing model_sparse.config._attn_implementation to ensure the
Triton kernel registration took effect.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ce048172-0588-4aec-9474-44eb6c4cbe3b

📥 Commits

Reviewing files that changed from the base of the PR and between 7aa6960 and 31655ce.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py

@kaix-nv kaix-nv changed the title Add sparse softmax to the Triton flash attention kernel [2/n] Add sparse softmax to the Triton flash attention kernel Mar 23, 2026
@kaix-nv kaix-nv requested review from Edwardf0t1 March 24, 2026 00:16
mask = torch.ones_like(attention_scores, dtype=torch.bool)
return mask, {}

def apply_sparsity(self, attention_scores, sparse_mask=None):
Copy link
Contributor


It would be better if, for Triton-backend attention sparsity methods, we don't have to add the boilerplate in this file (i.e. not have to add a new file to methods/ altogether)

Copy link
Contributor Author


Good catch. I've made calculate_sparsity and apply_sparsity optional in the base class.

@kaix-nv
Copy link
Contributor Author

kaix-nv commented Mar 24, 2026

@coderabbitai full review

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 24, 2026

✅ Actions performed

Full review triggered.

@github-actions
Copy link
Contributor

github-actions bot commented Mar 24, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-03-26 04:31 UTC

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (4)
modelopt/torch/kernels/triton_fa.py (4)

619-631: ⚠️ Potential issue | 🟠 Major

Backward dK/dV: same tile-based locality issue applies here.

See comment on forward kernel (lines 279-292). The fix should be applied consistently here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 619 - 631, The backward
dK/dV sparse-N mask is missing the same tile-locality handling as the forward
kernel; update the backward block (the code around where scores are processed in
the backward pass) to perform the same token-level/tile locality checks (compute
is_sink using kv_start and NUM_SINK_TOKENS, compute causal_offset = seq_len_kv -
seq_len_q, q_abs_pos = qi * BLOCK_M + causal_offset, token_distance = q_abs_pos
- kv_start, is_local = (token_distance >= 0) and (token_distance <
DENSE_WINDOW_SIZE)) and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M,
BLOCK_N, SPARSITY_N, SPARSITY_M) when not is_sink and not is_local, mirroring
the forward kernel's logic so dK/dV uses identical N:M sparsity masking
decisions.

280-280: ⚠️ Potential issue | 🟠 Major

Add decode guard — sparsity should be prefill-only per PR description.

The PR states "sparsity masking is applied during prefill only," but the kernel only checks SPARSITY_N > 0. During decode (seq_len_q == 1), the KV cache will still be sparsified, which could degrade generation quality.

🔧 Proposed fix
 if SPARSITY_N > 0:
+    # Skip sparsity during decode (seq_len_q == 1) — apply only during prefill
+    is_decode = seq_len_q == 1
     is_sink = kv_start < NUM_SINK_TOKENS
     causal_offset = seq_len_kv - seq_len_q
     q_abs_pos = tile_q * BLOCK_M + causal_offset
     token_distance = q_abs_pos - kv_start
     is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
-    if not is_sink and not is_local:
+    if not is_decode and not is_sink and not is_local:
         scores = _apply_sparse_nm_to_qk_tile(...)

Apply the same guard in _attn_bwd_dq (line 481) and _attn_bwd_dkdv (line 621).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` at line 280, The sparsity mask is
currently applied unconditionally when SPARSITY_N > 0; change the guard in the
backward kernels _attn_bwd_dq and _attn_bwd_dkdv so sparsity is only applied
during prefill (i.e., skip when decoding with seq_len_q == 1). Locate the
existing "if SPARSITY_N > 0:" checks inside _attn_bwd_dq and _attn_bwd_dkdv and
strengthen them to also require seq_len_q != 1 (for example: if SPARSITY_N > 0
and seq_len_q != 1) so the KV cache is not sparsified during decode.

480-491: ⚠️ Potential issue | 🟠 Major

Backward dQ: same tile-based locality issue applies here.

See comment on forward kernel (lines 279-292). The fix should be applied consistently here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 480 - 491, This backward dQ
block repeats the same tile-locality bug as the forward kernel; replicate the
forward-kernel fix (the block-based locality check used around lines 279-292)
here: compute the tile/block-level positions using tile_q, BLOCK_M and the
corresponding kv tile start (using tile_k or kv_start aligned to BLOCK_N),
derive q_abs_pos and token_distance at block granularity, set is_local based on
the block/window bounds (and preserve the is_sink check with NUM_SINK_TOKENS),
and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N,
SPARSITY_M) when the tile is neither sink nor local—i.e., exactly mirror the
forward kernel's locality logic and conditions so forward and backward agree.

279-292: ⚠️ Potential issue | 🟠 Major

Forward/backward sparse mask mismatch due to autotuned vs fixed tile sizes.

q_abs_pos = tile_q * BLOCK_M + causal_offset depends on BLOCK_M, but forward autotunes BLOCK_M over {64, 128} while backward hardcodes BLOCK = 64. For the same query row, the computed q_abs_pos and thus is_local can differ between forward and backward, causing gradient mismatch.

For example, query position 70 with causal_offset=0:

  • Forward (BLOCK_M=128): tile_q=0, q_abs_pos=0
  • Backward (BLOCK_M=64): tile_q=1, q_abs_pos=64

This changes whether tiles fall within DENSE_WINDOW_SIZE, applying different sparse masks.
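The divergence quoted above can be checked with a few lines of arithmetic (an illustrative standalone helper, not kernel code):

```python
def q_abs_pos(query_row, block_m, causal_offset=0):
    """Tile-level query position as computed by the kernel's locality check."""
    tile_q = query_row // block_m            # which Q tile contains this row
    return tile_q * block_m + causal_offset  # position the is_local test compares

# Same absolute query row 70, causal_offset=0:
print(q_abs_pos(70, 128))  # forward pass, autotuned BLOCK_M=128 -> 0
print(q_abs_pos(70, 64))   # backward pass, fixed BLOCK=64 -> 64
```

Because the two passes disagree on q_abs_pos for the same row, a DENSE_WINDOW_SIZE test applied to that value can classify the tile as local in one pass and sparse in the other.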

Consider computing locality per-row (using the actual row offset within the tile) rather than per-tile:

 if SPARSITY_N > 0:
     is_sink = kv_start < NUM_SINK_TOKENS
     causal_offset = seq_len_kv - seq_len_q
-    q_abs_pos = tile_q * BLOCK_M + causal_offset
-    token_distance = q_abs_pos - kv_start
-    is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
-    if not is_sink and not is_local:
+    # Per-row locality: check if ANY row in this Q tile is within dense window
+    q_tile_start = tile_q * BLOCK_M + causal_offset
+    q_tile_end = q_tile_start + BLOCK_M - 1
+    # Tile overlaps dense window if any Q row is within DENSE_WINDOW_SIZE of kv_start
+    tile_overlaps_window = (q_tile_start - kv_start < DENSE_WINDOW_SIZE) and (q_tile_start >= kv_start)
+    if not is_sink and not tile_overlaps_window:
         scores = _apply_sparse_nm_to_qk_tile(...)

Alternatively, fix BLOCK_M to match backward's block size when sparsity is enabled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The
sparse/locality check uses q_abs_pos = tile_q * BLOCK_M which diverges between
forward (autotuned BLOCK_M) and backward (fixed BLOCK=64); fix by computing
locality per-row instead of per-tile: derive the absolute query row index as
(tile_q * BLOCK_M + row_within_tile) where row_within_tile is the actual row
offset inside the current tile (from the loop/index that produces scores), then
use that q_abs_pos for is_local/is_sink logic before calling
_apply_sparse_nm_to_qk_tile; alternatively, if you prefer the simpler change,
force BLOCK_M to the backward block size (64) whenever SPARSITY_N > 0 so forward
and backward use the same tile size.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 515c68d0-8207-4d6f-9009-6aa727c66189

📥 Commits

Reviewing files that changed from the base of the PR and between 31655ce and 9882dbb.

📒 Files selected for processing (6)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (4)
tests/gpu/torch/sparsity/attention_sparsity/conftest.py (1)

1-1: Update copyright year to 2026.

The license header has copyright year 2024, but the current year is 2026. Consider updating for consistency with the PR date.

📝 Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/conftest.py` at line 1, Update
the SPDX license header year from 2024 to 2026 at the top of the conftest.py
file (the file-level comment starting with "SPDX-FileCopyrightText") so the
copyright line reflects the current year.
modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py (1)

1-1: Update copyright year to 2026.

The license header has copyright year 2024, but this is a new file created in 2026.

📝 Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py`
at line 1, Update the SPDX copyright header in the file that currently reads
"Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES" to use the year 2026;
locate the header at the top of the file (the SPDX/FileCopyrightText comment)
and change the year to 2026 so the license header reflects the file creation
year.
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py (2)

232-243: Clarify assertion expectation for M=8 pattern.

The test parametrizes over both M=4 and M=8 patterns, but line 241's assertion (kept == n).all() expects exactly N kept. For M=8 with tl.sort-based thresholding, ties may keep ≥N elements. While random inputs (line 236) make ties unlikely, consider adding a comment or splitting the test.

The separate test_sparsity_structure_ties at lines 250-264 correctly handles this distinction, so this is a minor documentation concern.

📝 Suggested clarification
     def test_sparsity_structure(self, n, m):
-        """Verify N:M structure: exactly N kept per group of M."""
+        """Verify N:M structure: exactly N kept per group of M (random input avoids ties)."""
         bm, bn = 32, 64
         torch.manual_seed(88)
         tile = torch.randn(bm, bn, device="cuda", dtype=torch.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py`
around lines 232 - 243, test_sparsity_structure currently asserts exactly N kept
per group with (kept == n).all() but for M=8 the tl.sort-based thresholding can
legitimately keep >=N on ties; update the test to either relax the check for
m==8 (use (kept >= n).all() when m == 8) or add a clear comment next to the
assertion explaining that ties for M=8 may produce >=N and that
test_sparsity_structure_ties covers strict tie behavior; reference
test_sparsity_structure, the kept variable, and the call to
_test_apply_sparse_nm so reviewers can find and apply the change.
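To make the tie behavior discussed above concrete, here is a pure-Python sketch of threshold-based keeping (the kernel's tl.sort-based selection behaves analogously; the function name is illustrative):

```python
def kept_per_group(scores, n, m):
    """Count survivors of threshold-based N:M selection: within each group of M,
    keep every value >= the N-th largest. Ties at the threshold inflate the count."""
    counts = []
    for i in range(0, len(scores), m):
        group = scores[i:i + m]
        thresh = sorted(group, reverse=True)[n - 1]
        counts.append(sum(1 for s in group if s >= thresh))
    return counts

print(kept_per_group([4.0, 3.0, 2.0, 1.0], 2, 4))  # distinct values -> [2]
print(kept_per_group([3.0, 3.0, 3.0, 1.0], 2, 4))  # tie at the threshold -> [3]
```

With random float inputs ties are vanishingly unlikely, which is why the `(kept == n).all()` assertion passes in practice, but an exact-equality check is only guaranteed for distinct values.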

1-1: Update copyright year to 2026.

New file should have current copyright year.

📝 Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py` at
line 1, Update the SPDX copyright header line that currently reads "Copyright
(c) 2024 NVIDIA CORPORATION & AFFILIATES." to the current year 2026; locate the
SPDX header (the line beginning with "SPDX-FileCopyrightText") in
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py and
replace 2024 with 2026 so the SPDX header reflects Copyright (c) 2026 NVIDIA
CORPORATION & AFFILIATES.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py`:
- Line 1: Update the SPDX copyright header in the file that currently reads
"Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES" to use the year 2026;
locate the header at the top of the file (the SPDX/FileCopyrightText comment)
and change the year to 2026 so the license header reflects the file creation
year.

In `@tests/gpu/torch/sparsity/attention_sparsity/conftest.py`:
- Line 1: Update the SPDX license header year from 2024 to 2026 at the top of
the conftest.py file (the file-level comment starting with
"SPDX-FileCopyrightText") so the copyright line reflects the current year.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py`:
- Around line 232-243: test_sparsity_structure currently asserts exactly N kept
per group with (kept == n).all() but for M=8 the tl.sort-based thresholding can
legitimately keep >=N on ties; update the test to either relax the check for
m==8 (use (kept >= n).all() when m == 8) or add a clear comment next to the
assertion explaining that ties for M=8 may produce >=N and that
test_sparsity_structure_ties covers strict tie behavior; reference
test_sparsity_structure, the kept variable, and the call to
_test_apply_sparse_nm so reviewers can find and apply the change.
- Line 1: Update the SPDX copyright header line that currently reads "Copyright
(c) 2024 NVIDIA CORPORATION & AFFILIATES." to the current year 2026; locate the
SPDX header (the line beginning with "SPDX-FileCopyrightText") in
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py and
replace 2024 with 2026 so the SPDX header reflects Copyright (c) 2026 NVIDIA
CORPORATION & AFFILIATES.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: dac2310d-e99e-47c1-b62c-962ab622bdd4

📥 Commits

Reviewing files that changed from the base of the PR and between 08e5f92 and 9882dbb.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py

@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from 9882dbb to 67ae67b on March 24, 2026 at 20:56
@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)

84-94: Consider adding @abstractmethod decorator for consistency.

get_sparse_context raises NotImplementedError but lacks the @abstractmethod decorator, unlike the name property (line 109). This inconsistency means subclasses won't get a clear error at instantiation time if they forget to implement it—they'll only fail at runtime when the method is called.

✨ Suggested change
+    @abstractmethod
     def get_sparse_context(self, module: torch.nn.Module):
         """Return a context manager that activates this method's sparsity during forward.

         Each method subclass implements its own activation mechanism:
         - Softmax-patching methods replace F.softmax during the forward pass.
         - Kernel-fused methods set flags on ``module`` that the kernel reads.

         Args:
             module: The SparseAttentionModule wrapping the attention layer.
         """
-        raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py` around lines
84 - 94, get_sparse_context currently raises NotImplementedError but isn't
decorated as abstract, causing subclasses to only error at runtime; mark
get_sparse_context with the `@abstractmethod` decorator (same style used for the
name property) so subclasses of the registry base class must implement
get_sparse_context (which should return a context manager for activating
sparsity on the SparseAttentionModule) and ensure imports/ABC usage are
consistent with the existing abstract methods.
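The instantiation-time benefit is easy to demonstrate with a minimal stand-in for the base class (class names here are illustrative, not the actual registry code):

```python
from abc import ABC, abstractmethod

class SparseMethodBase(ABC):  # simplified stand-in for the registry base class
    @abstractmethod
    def get_sparse_context(self, module):
        """Return a context manager that activates sparsity on `module`."""

class Forgetful(SparseMethodBase):
    pass  # forgot to implement get_sparse_context

try:
    Forgetful()
except TypeError as e:
    # With @abstractmethod, the error surfaces here, at instantiation,
    # instead of at the first (possibly much later) method call.
    print(f"TypeError: {e}")
```

Without the decorator, `Forgetful()` would construct fine and only raise `NotImplementedError` when `get_sparse_context` is eventually called.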
modelopt/torch/kernels/triton_fa.py (1)

866-871: Consider adding input validation for sparsity parameters.

The kernel enforces constraints via tl.static_assert (SPARSITY_M must be 4 or 8), but invalid combinations at the Python API level could produce unexpected behavior. For example, sparsity_n >= sparsity_m or sparsity_n < 0 won't be caught until kernel compilation.

✨ Suggested validation
 def attention(
     q: torch.Tensor,
     ...
     *,
     sparsity_n: int = 0,
     sparsity_m: int = 4,
     num_sink_tokens: int = 0,
     dense_window_size: int = 64,
 ) -> torch.Tensor:
+    if sparsity_n < 0:
+        raise ValueError(f"sparsity_n must be non-negative, got {sparsity_n}")
+    if sparsity_n > 0:
+        if sparsity_m not in (4, 8):
+            raise ValueError(f"sparsity_m must be 4 or 8, got {sparsity_m}")
+        if sparsity_n >= sparsity_m:
+            raise ValueError(f"sparsity_n ({sparsity_n}) must be < sparsity_m ({sparsity_m})")
     ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 866 - 871, Add Python-level
input validation at the start of the function that accepts sparsity_n and
sparsity_m (the function with parameters sparsity_n: int = 0, sparsity_m: int =
4, num_sink_tokens, dense_window_size) to prevent invalid combos before kernel
compilation: check that sparsity_m is one of the supported values (4 or 8),
sparsity_n is an int >= 0 and strictly less than sparsity_m, and that both are
integers; if any check fails, raise a ValueError with a clear message
referencing sparsity_n/sparsity_m so users see the invalid values. Ensure these
checks run before any tl.static_assert or kernel compilation logic so invalid
inputs are caught early.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8083fa36-d78b-4fde-9ad9-9ea3c8eb1666

📥 Commits

Reviewing files that changed from the base of the PR and between 9882dbb and 67ae67b.

📒 Files selected for processing (10)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (3)
  • CHANGELOG.rst
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py

@Edwardf0t1 Edwardf0t1 left a comment

LGTM overall, left some comments.


# N:M sparse softmax — prefill only (decode should not sparsify KV)
if not is_decode and getattr(module, "_apply_sparse_nm", False):
    method = getattr(module, "_sparse_method_instance", None)
Contributor

Where _sparse_method_instance gets set, if it's outside this PR, please add a comment.

Contributor Author

Added

Comment on lines +283 to +292
is_sink = kv_start < NUM_SINK_TOKENS
# causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q)
causal_offset = seq_len_kv - seq_len_q
q_abs_pos = tile_q * BLOCK_M + causal_offset
token_distance = q_abs_pos - kv_start
is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
if not is_sink and not is_local:
    scores = _apply_sparse_nm_to_qk_tile(
        scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
    )
Contributor

Consider extracting the duplicated sink/window check into a shared helper function.

Contributor Author

Good catch, fixed.

kaix-nv added 5 commits March 25, 2026 19:22
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from 67ae67b to ee10e20 on March 26, 2026 at 02:53
@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 1

♻️ Duplicate comments (3)
modelopt/torch/sparsity/attention_sparsity/config.py (1)

228-238: ⚠️ Potential issue | 🟠 Major

Validate triton_sparse_softmax against backend="triton" here.

This validator still accepts {"method": "triton_sparse_softmax", "backend": "pytorch"}. The example CLI can create exactly that config, and TritonSparseSoftmaxMethod relies on the fused Triton path rather than a PyTorch implementation, so this should fail during config validation instead of later.

🛠️ Suggested validation
     @model_validator(mode="after")
     def validate_sparsity_n_vs_m(self):
         """Validate sparsity_n is within the supported range for the given sparsity_m."""
+        if self.method == "triton_sparse_softmax" and self.backend != "triton":
+            raise ValueError("triton_sparse_softmax requires backend='triton'")
         if self.sparsity_n > 0:
             max_n = 3 if self.sparsity_m == 4 else self.sparsity_m - 1
             if self.sparsity_n > max_n:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 228 - 238,
The current validator validate_sparsity_n_vs_m only checks n vs m and misses
validating that method=="triton_sparse_softmax" requires backend=="triton";
update that validator (or add a new `@model_validator(mode="after")` in the same
config class) to check self.method (or the enum value for
TritonSparseSoftmaxMethod) and raise a ValueError if self.method ==
"triton_sparse_softmax" and self.backend != "triton", giving a clear message
like "triton_sparse_softmax requires backend='triton'"; keep this validation in
the config class so invalid combos are rejected during config validation.
modelopt/torch/kernels/triton_fa.py (2)

308-321: ⚠️ Potential issue | 🟠 Major

Keep the prefill-only contract inside the public Triton API too.

These branches only gate on SPARSITY_N > 0, so direct calls to attention(..., b_start_loc_k=..., b_seq_len_k=..., sparsity_n>0) will sparsify cached decode KV even though the feature is documented as prefill-only. The HF wrapper avoids passing these kwargs on decode, but the public kernel API now exposes them directly, so forward and both backward recompute paths should enforce the same guard.

🛠️ Suggested guard
-        if SPARSITY_N > 0:
+        is_decode = seq_len_q == 1 and seq_len_kv > 1
+        if SPARSITY_N > 0 and not is_decode:

Apply the same predicate in _attn_fwd, _attn_bwd_dq, and _attn_bwd_dkdv.

Also applies to: 509-523, 651-665

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 308 - 321, The N:M sparsity
branches currently run whenever SPARSITY_N > 0, which lets external callers
apply sparsity during decode; update the public Triton kernel entry points to
enforce the prefill-only contract by combining the SPARSITY_N check with the
same prefill predicate used by the HF wrapper (i.e., only apply sparsification
when this is a prefill call — the decode-specific KV indices/lengths are not
provided). Concretely, in _attn_fwd, _attn_bwd_dq, and _attn_bwd_dkdv wrap the
existing SPARSITY_N > 0 blocks with an additional condition that b_start_loc_k
and b_seq_len_kv indicate a prefill call (for example, only run the sparse
branch when b_start_loc_k and b_seq_len_kv are unset/None or otherwise denote
prefill), so cached decode KV paths are not sparsified.

161-185: ⚠️ Potential issue | 🔴 Critical

The “token” dense-region knobs are still quantized by launch tiles.

_is_dense_region() runs once per (q_tile, kv_tile) and only looks at kv_start and tile_q * BLOCK_M. So num_sink_tokens=4 keeps the entire first KV tile dense (32/64/128 tokens depending on BLOCK_N), not the first 4 tokens, and the local-window decision changes when forward autotunes BLOCK_M/BLOCK_N while backward recomputes with fixed 64x64 tiles. For example, q[64:127] shares tile_q=0 in forward when BLOCK_M=128, but both backward kernels recompute it as tile_q=1, so dense_window_size=64 flips kv_start=64 from sparse to dense. That makes backward differentiate a different mask than forward.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 161 - 185, _is_dense_region
currently quantizes decisions to tile-aligned points (tile_q * BLOCK_M and
kv_start), causing behavior to flip when BLOCK_M/BLOCK_N changes; fix it by
treating KV and Q tiles as token ranges and checking range-overlap with the
dense intervals instead of single-point comparisons: compute kv_tile_start =
kv_start, kv_tile_end = min(kv_start + BLOCK_N, seq_len_kv), q_abs_start =
tile_q * BLOCK_M + (seq_len_kv - seq_len_q), q_abs_end = min(q_abs_start +
BLOCK_M, seq_len_q), then return true if the KV tile range overlaps the sink
range [0, NUM_SINK_TOKENS) or the local window [q_abs_start - DENSE_WINDOW_SIZE
+ 1, q_abs_end) (use inclusive/exclusive semantics consistently); update
_is_dense_region to use these range-overlap checks so decisions are
token-accurate and invariant to BLOCK_M/BLOCK_N.
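A plain-Python sketch of the range-overlap check described in this prompt (the signature and names are illustrative; the real kernel receives these as Triton constexpr parameters):

```python
def is_dense_region(tile_q: int, kv_start: int,
                    block_m: int, block_n: int,
                    seq_len_q: int, seq_len_kv: int,
                    num_sink_tokens: int, dense_window_size: int) -> bool:
    """Token-accurate test: does this (q_tile, kv_tile) pair touch a dense interval?"""
    kv_tile_start = kv_start
    kv_tile_end = min(kv_start + block_n, seq_len_kv)      # exclusive
    causal_offset = seq_len_kv - seq_len_q                 # chunked prefill
    q_abs_start = tile_q * block_m + causal_offset
    q_abs_end = min(q_abs_start + block_m, seq_len_kv)     # exclusive

    # Sink region: KV tile overlaps [0, num_sink_tokens).
    overlaps_sink = kv_tile_start < num_sink_tokens

    # Local window: some (q, kv) pair in the tile satisfies
    # 0 <= q - kv < dense_window_size. The achievable distances q - kv
    # span [min_dist, max_dist], so intersect that with [0, window).
    max_dist = (q_abs_end - 1) - kv_tile_start
    min_dist = q_abs_start - (kv_tile_end - 1)
    overlaps_window = max_dist >= 0 and min_dist < dense_window_size

    return overlaps_sink or overlaps_window
```

Because the decision depends only on token ranges, it stays consistent when forward autotunes BLOCK_M/BLOCK_N to values that differ from the backward kernels' fixed 64x64 tiles.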
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 31-35: The CLI exposes SPARSE_SOFTMAX_DEFAULT which requires
Triton but main() later overrides per-layer backends from args.backend (default
"pytorch"), so invoking hf_sa.py --sparse_attn sparse_softmax yields an invalid
config; after argument parsing in main() detect if args.sparse_attn (or the
per-layer choice derived from SPARSE_SOFTMAX_DEFAULT) equals "sparse_softmax"
and either (a) set args.backend = "triton" automatically or (b) raise a clear
error instructing the user to pass --backend triton; update the same check site
that currently rewrites layers (the code that reads args.backend and applies
per-layer rewrites) so it enforces or auto-selects Triton for sparse_softmax.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: dd3e47ce-a9c5-4cb5-a2ad-68b374dc46ea

📥 Commits

Reviewing files that changed from the base of the PR and between 67ae67b and ee10e20.

📒 Files selected for processing (12)
  • CHANGELOG.rst
  • examples/llm_sparsity/attention_sparsity/README.md
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (3)
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • CHANGELOG.rst
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py

Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from ee10e20 to 9fdeccd on March 26, 2026 at 03:32
@kaix-nv (Contributor Author)

kaix-nv commented Mar 26, 2026

@coderabbitai full review

@coderabbitai (Contributor)

coderabbitai bot commented Mar 26, 2026

✅ Actions performed

Full review triggered.

@kaix-nv kaix-nv merged commit b1f9f01 into main Mar 26, 2026
45 checks passed
@kaix-nv kaix-nv deleted the kaix/triton_fa_sparse24 branch March 26, 2026 04:30