[2/n] Add sparse softmax to the Triton flash attention kernel #1078
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Caution: Review failed. The pull request was closed or merged during review.

Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. This behavior is configurable in the review settings.
📝 Walkthrough

Adds Triton-backed N:M sparse softmax to flash attention: kernel tile helpers and constexpr params, autograd and public API plumbing, sparsity method/config registration and HF prefill gating, plus GPU tests and a changelog entry.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller as "Caller / Model"
    participant API as "attention(...) / _Attention"
    participant Autograd as "Autograd Function"
    participant Triton as "Triton FA Kernel"
    participant KV as "KV Storage / Tiles"
    rect rgba(100,149,237,0.5)
    Caller->>API: call attention(q,k,v,..., sparsity_n, sparsity_m, num_sink_tokens, dense_window_size)
    API->>Autograd: _Attention.forward(..., sparsity params)
    Autograd->>Triton: launch forward kernel (constexpr sparsity params)
    Triton->>KV: load QK tile
    Triton->>Triton: apply N:M mask (tile-level) -> set pruned scores to -inf
    Triton->>Triton: softmax & compute output context
    Triton-->>Autograd: return outputs + saved tensors (incl. sparsity params)
    end
    rect rgba(60,179,113,0.5)
    Autograd->>Triton: backward launch with same sparsity params
    Triton->>Triton: recompute masked scores (respect sink/window), compute dq/dk/dv
    Triton-->>Autograd: return gradients
    Autograd-->>Caller: propagate gradients
    end
```
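The "apply N:M mask" step in the sequence above can be illustrated with a plain-Python reference. This is a hypothetical CPU analogue of the kernel's `_apply_sparse_nm_to_qk_tile` helper; the function name and tie handling here are illustrative, not the kernel's exact `tl.sort`-based implementation:

```python
import math

def sparse_nm_row(scores, n, m):
    """Illustrative N:M sparsification of one row of attention scores:
    within each contiguous group of m scores, keep the n largest and
    set the rest to -inf so softmax assigns them zero probability."""
    assert len(scores) % m == 0, "row length must be a multiple of m"
    out = []
    for g in range(0, len(scores), m):
        group = scores[g:g + m]
        # the n-th largest value acts as the keep threshold (ties keep >= n)
        threshold = sorted(group, reverse=True)[n - 1]
        out.extend(s if s >= threshold else -math.inf for s in group)
    return out

row = [0.1, 0.9, 0.4, 0.7, 0.2, 0.8, 0.3, 0.6]
print(sparse_nm_row(row, n=2, m=4))
# -> [-inf, 0.9, -inf, 0.7, -inf, 0.8, -inf, 0.6]
```

After this step, the surviving scores proceed through the usual softmax; the `-inf` entries contribute zero probability mass.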
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ Passed checks (4 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches: 📝 Generate docstrings | 🧪 Generate unit tests (beta)
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

@@ Coverage Diff @@
##             main    #1078      +/-   ##
==========================================
- Coverage   70.21%   70.18%   -0.03%
==========================================
  Files         228      229       +1
  Lines       25952    26008      +56
==========================================
+ Hits        18221    18254      +33
- Misses       7731     7754      +23

☔ View full report in Codecov by Sentry.
8ba6efe to 7aa6960 (Compare)
7aa6960 to 31655ce (Compare)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 279-292: The N:M sparsity branch is being applied during decode
because it only checks SPARSITY_N > 0; change the condition to skip
sparsification when doing cached decode (seq_len_q == 1). Update the if guarding
the block (the one that currently reads "if SPARSITY_N > 0:") to also require
not decoding (e.g., "if SPARSITY_N > 0 and seq_len_q != 1:") so
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) is
only called during prefill/non-decoding paths; keep the existing local/sink
logic (kv_start, tile_q, q_abs_block, is_local/is_sink) unchanged.
- Around line 279-292: The sparse-mask logic currently mixes tile-sized units
(BLOCK_M/BLOCK_N) with sparsity parameters (NUM_SINK_BLOCKS,
DENSE_WINDOW_BLOCKS) leading to inconsistent masks between forward/backward; fix
by computing mask membership in token-space instead of tile-space: derive each
KV block index and query row absolute token position from actual token counts
(use seq_len_kv, seq_len_q, kv_start and per-row start = tile_q * BLOCK_M +
row_offset or for whole-tile use tile_token_start = tile_q * BLOCK_M) and then
map those token positions into logical token-blocks of a fixed reference block
size (choose the constant used by backward kernels, e.g., 64 tokens) before
comparing to NUM_SINK_BLOCKS and DENSE_WINDOW_BLOCKS; update q_abs_block,
kv_block_idx, is_sink and is_local computations (the branch that calls
_apply_sparse_nm_to_qk_tile) and apply same normalization in the other
occurrences mentioned (around the other two locations) so forward and backward
use the same token-blocking semantics.
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-130: Add validation to the config so invalid N:M and negative
counts are rejected when building the config instead of failing in the Triton
kernel: in the class/constructor where ModeloptField(s) sparsity_n, sparsity_m,
num_sink_blocks, and dense_window_blocks are defined, validate that sparsity_m
is either 4 or 8, sparsity_n is in {1,2,3} when sparsity_m==4 and in {1,2,3,4}
when sparsity_m==8 (or 0/disabled as your semantics require), and that
num_sink_blocks and dense_window_blocks are non-negative; also ensure that the
chosen sparsity mode is only allowed when triton_sparse_softmax is selected or
available and raise a clear config validation error if any rule is violated so
bad configs fail early.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`:
- Around line 350-358: Add an explicit assertion after calling
mtsa.sparsify(...) to verify the Triton backend was applied: check that
model_sparse.config._attn_implementation == "modelopt_triton" (this should be
done right after mtsa.sparsify(...) returns and before comparing logits). Locate
the sparsification call (mtsa.sparsify(..., backend="triton", ...)) and add the
assertion referencing model_sparse.config._attn_implementation to ensure the
Triton kernel registration took effect.
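The config rules described above for `@modelopt/torch/sparsity/attention_sparsity/config.py` can be sketched as a plain validation function. This is illustrative only; the real config defines these as ModeloptField(s) with pydantic-style validators, and the function name here is hypothetical:

```python
def check_sparsity_config(sparsity_n, sparsity_m, num_sink_blocks, dense_window_blocks):
    """Sketch of the early validation the review asks for, mirroring the
    rules stated above (sparsity_n == 0 disables sparsity entirely)."""
    if num_sink_blocks < 0 or dense_window_blocks < 0:
        raise ValueError("num_sink_blocks and dense_window_blocks must be non-negative")
    if sparsity_n == 0:
        return  # sparsity disabled; the remaining rules do not apply
    if sparsity_m not in (4, 8):
        raise ValueError(f"sparsity_m must be 4 or 8, got {sparsity_m}")
    # allowed n: {1,2,3} for m==4, {1,2,3,4} for m==8, per the rules above
    max_n = 3 if sparsity_m == 4 else 4
    if not (1 <= sparsity_n <= max_n):
        raise ValueError(f"sparsity_n={sparsity_n} unsupported for sparsity_m={sparsity_m}")

check_sparsity_config(2, 4, 1, 2)   # valid: 2:4 sparsity with sinks and a window
check_sparsity_config(0, 4, 0, 0)   # valid: sparsity disabled
```

Failing in the validator like this surfaces a clear error at config-build time rather than deep inside Triton kernel compilation.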
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ce048172-0588-4aec-9474-44eb6c4cbe3b
📒 Files selected for processing (9)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (2)
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
    mask = torch.ones_like(attention_scores, dtype=torch.bool)
    return mask, {}

    def apply_sparsity(self, attention_scores, sparse_mask=None):
It would be better if, for Triton-backend attention sparsity methods, we don't have to add the boilerplate in this file (i.e. not have to add a new file to methods/ altogether)
Good catch. I've made calculate_sparsity and apply_sparsity optional in base class.
@coderabbitai full review

✅ Actions performed: Full review triggered.
♻️ Duplicate comments (4)
modelopt/torch/kernels/triton_fa.py (4)
619-631: ⚠️ Potential issue | 🟠 Major

Backward dK/dV: the same tile-based locality issue applies here.
See comment on forward kernel (lines 279-292). The fix should be applied consistently here.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 619 - 631, The backward dK/dV sparse-N mask is missing the same tile-locality handling as the forward kernel; update the backward block (the code around where scores are processed in the backward pass) to perform the same token-level/tile locality checks (compute is_sink using kv_start and NUM_SINK_TOKENS, compute causal_offset = seq_len_kv - seq_len_q, q_abs_pos = qi * BLOCK_M + causal_offset, token_distance = q_abs_pos - kv_start, is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)) and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) when not is_sink and not is_local, mirroring the forward kernel's logic so dK/dV uses identical N:M sparsity masking decisions.
280-280:⚠️ Potential issue | 🟠 MajorAdd decode guard — sparsity should be prefill-only per PR description.
The PR states "sparsity masking is applied during prefill only," but the kernel only checks
SPARSITY_N > 0. During decode (seq_len_q == 1), the KV cache will still be sparsified, which could degrade generation quality.🔧 Proposed fix
if SPARSITY_N > 0: + # Skip sparsity during decode (seq_len_q == 1) — apply only during prefill + is_decode = seq_len_q == 1 is_sink = kv_start < NUM_SINK_TOKENS causal_offset = seq_len_kv - seq_len_q q_abs_pos = tile_q * BLOCK_M + causal_offset token_distance = q_abs_pos - kv_start is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE) - if not is_sink and not is_local: + if not is_decode and not is_sink and not is_local: scores = _apply_sparse_nm_to_qk_tile(...)Apply the same guard in
_attn_bwd_dq(line 481) and_attn_bwd_dkdv(line 621).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` at line 280, The sparsity mask is currently applied unconditionally when SPARSITY_N > 0; change the guard in the backward kernels _attn_bwd_dq and _attn_bwd_dkdv so sparsity is only applied during prefill (i.e., skip when decoding with seq_len_q == 1). Locate the existing "if SPARSITY_N > 0:" checks inside _attn_bwd_dq and _attn_bwd_dkdv and strengthen them to also require seq_len_q != 1 (for example: if SPARSITY_N > 0 and seq_len_q != 1) so the KV cache is not sparsified during decode.
480-491: ⚠️ Potential issue | 🟠 Major

Backward dQ: the same tile-based locality issue applies here.
See comment on forward kernel (lines 279-292). The fix should be applied consistently here.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 480 - 491, This backward dQ block repeats the same tile-locality bug as the forward kernel; replicate the forward-kernel fix (the block-based locality check used around lines 279-292) here: compute the tile/block-level positions using tile_q, BLOCK_M and the corresponding kv tile start (using tile_k or kv_start aligned to BLOCK_N), derive q_abs_pos and token_distance at block granularity, set is_local based on the block/window bounds (and preserve the is_sink check with NUM_SINK_TOKENS), and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) when the tile is neither sink nor local—i.e., exactly mirror the forward kernel's locality logic and conditions so forward and backward agree.
279-292: ⚠️ Potential issue | 🟠 Major

Forward/backward sparse mask mismatch due to autotuned vs. fixed tile sizes.

`q_abs_pos = tile_q * BLOCK_M + causal_offset` depends on BLOCK_M, but the forward kernel autotunes BLOCK_M over {64, 128} while the backward kernel hardcodes BLOCK = 64. For the same query row, the computed q_abs_pos, and thus is_local, can differ between forward and backward, causing gradient mismatch.

For example, query position 70 with causal_offset=0:

- Forward (BLOCK_M=128): tile_q=0, q_abs_pos=0
- Backward (BLOCK_M=64): tile_q=1, q_abs_pos=64

This changes whether tiles fall within DENSE_WINDOW_SIZE, applying different sparse masks.

Consider computing locality per-row (using the actual row offset within the tile) rather than per-tile:

        if SPARSITY_N > 0:
            is_sink = kv_start < NUM_SINK_TOKENS
            causal_offset = seq_len_kv - seq_len_q
    -       q_abs_pos = tile_q * BLOCK_M + causal_offset
    -       token_distance = q_abs_pos - kv_start
    -       is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
    -       if not is_sink and not is_local:
    +       # Per-row locality: check if ANY row in this Q tile is within the dense window
    +       q_tile_start = tile_q * BLOCK_M + causal_offset
    +       q_tile_end = q_tile_start + BLOCK_M - 1
    +       # Tile overlaps the dense window if any Q row is within DENSE_WINDOW_SIZE of kv_start
    +       tile_overlaps_window = (q_tile_start - kv_start < DENSE_WINDOW_SIZE) and (q_tile_start >= kv_start)
    +       if not is_sink and not tile_overlaps_window:
                scores = _apply_sparse_nm_to_qk_tile(...)

Alternatively, fix BLOCK_M to match the backward block size when sparsity is enabled.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The sparse/locality check uses q_abs_pos = tile_q * BLOCK_M which diverges between forward (autotuned BLOCK_M) and backward (fixed BLOCK=64); fix by computing locality per-row instead of per-tile: derive the absolute query row index as (tile_q * BLOCK_M + row_within_tile) where row_within_tile is the actual row offset inside the current tile (from the loop/index that produces scores), then use that q_abs_pos for is_local/is_sink logic before calling _apply_sparse_nm_to_qk_tile; alternatively, if you prefer the simpler change, force BLOCK_M to the backward block size (64) whenever SPARSITY_N > 0 so forward and backward use the same tile size.
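The mismatch in the worked example above is easy to reproduce with plain arithmetic. A sketch of the tile-granular position computation alongside a tile-size-independent per-row alternative (function names are illustrative, not the kernel's):

```python
def tile_q_abs_pos(query_row, block_m, causal_offset=0):
    """Tile-granular position as the kernel computes it: every row in a tile
    is labelled with the tile's first-row position (tile_q * BLOCK_M)."""
    tile_q = query_row // block_m
    return tile_q * block_m + causal_offset

# Query row 70, causal_offset 0 (the example from the comment above):
print(tile_q_abs_pos(70, 128))  # forward, autotuned BLOCK_M=128 -> 0
print(tile_q_abs_pos(70, 64))   # backward, fixed BLOCK=64 -> 64

def row_abs_pos(query_row, causal_offset=0):
    """Per-row alternative: independent of tile size, so forward and
    backward agree for every query row."""
    return query_row + causal_offset

print(row_abs_pos(70))  # -> 70, regardless of BLOCK_M
```

Any locality decision derived from `tile_q_abs_pos` changes with the tile size, while one derived from `row_abs_pos` does not; that is the crux of the forward/backward disagreement.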
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 619-631: The backward dK/dV sparse-N mask is missing the same
tile-locality handling as the forward kernel; update the backward block (the
code around where scores are processed in the backward pass) to perform the same
token-level/tile locality checks (compute is_sink using kv_start and
NUM_SINK_TOKENS, compute causal_offset = seq_len_kv - seq_len_q, q_abs_pos = qi
* BLOCK_M + causal_offset, token_distance = q_abs_pos - kv_start, is_local =
(token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)) and only call
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M)
when not is_sink and not is_local, mirroring the forward kernel's logic so dK/dV
uses identical N:M sparsity masking decisions.
- Line 280: The sparsity mask is currently applied unconditionally when
SPARSITY_N > 0; change the guard in the backward kernels _attn_bwd_dq and
_attn_bwd_dkdv so sparsity is only applied during prefill (i.e., skip when
decoding with seq_len_q == 1). Locate the existing "if SPARSITY_N > 0:" checks
inside _attn_bwd_dq and _attn_bwd_dkdv and strengthen them to also require
seq_len_q != 1 (for example: if SPARSITY_N > 0 and seq_len_q != 1) so the KV
cache is not sparsified during decode.
- Around line 480-491: This backward dQ block repeats the same tile-locality bug
as the forward kernel; replicate the forward-kernel fix (the block-based
locality check used around lines 279-292) here: compute the tile/block-level
positions using tile_q, BLOCK_M and the corresponding kv tile start (using
tile_k or kv_start aligned to BLOCK_N), derive q_abs_pos and token_distance at
block granularity, set is_local based on the block/window bounds (and preserve
the is_sink check with NUM_SINK_TOKENS), and only call
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M)
when the tile is neither sink nor local—i.e., exactly mirror the forward
kernel's locality logic and conditions so forward and backward agree.
- Around line 279-292: The sparse/locality check uses q_abs_pos = tile_q *
BLOCK_M which diverges between forward (autotuned BLOCK_M) and backward (fixed
BLOCK=64); fix by computing locality per-row instead of per-tile: derive the
absolute query row index as (tile_q * BLOCK_M + row_within_tile) where
row_within_tile is the actual row offset inside the current tile (from the
loop/index that produces scores), then use that q_abs_pos for is_local/is_sink
logic before calling _apply_sparse_nm_to_qk_tile; alternatively, if you prefer
the simpler change, force BLOCK_M to the backward block size (64) whenever
SPARSITY_N > 0 so forward and backward use the same tile size.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 515c68d0-8207-4d6f-9009-6aa727c66189
📒 Files selected for processing (6)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🧹 Nitpick comments (4)
tests/gpu/torch/sparsity/attention_sparsity/conftest.py (1)

1-1: Update copyright year to 2026.

The license header has copyright year 2024, but the current year is 2026. Consider updating for consistency with the PR date.

📝 Suggested fix:

    -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/conftest.py` at line 1, update the SPDX license header year from 2024 to 2026 at the top of the conftest.py file (the file-level comment starting with "SPDX-FileCopyrightText") so the copyright line reflects the current year.

modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py (1)
1-1: Update copyright year to 2026.

The license header has copyright year 2024, but this is a new file created in 2026.

📝 Suggested fix:

    -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py` at line 1, update the SPDX copyright header that currently reads "Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES" to use the year 2026, so the license header reflects the file creation year.

tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py (2)
232-243: Clarify assertion expectation for the M=8 pattern.

The test parametrizes over both M=4 and M=8 patterns, but line 241's assertion `(kept == n).all()` expects exactly N kept. For M=8 with tl.sort-based thresholding, ties may keep ≥N elements. While random inputs (line 236) make ties unlikely, consider adding a comment or splitting the test.

The separate test_sparsity_structure_ties at lines 250-264 correctly handles this distinction, so this is a minor documentation concern.

📝 Suggested clarification:

        def test_sparsity_structure(self, n, m):
    -       """Verify N:M structure: exactly N kept per group of M."""
    +       """Verify N:M structure: exactly N kept per group of M (random input avoids ties)."""
            bm, bn = 32, 64
            torch.manual_seed(88)
            tile = torch.randn(bm, bn, device="cuda", dtype=torch.float32)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py` around lines 232 - 243, test_sparsity_structure currently asserts exactly N kept per group with (kept == n).all() but for M=8 the tl.sort-based thresholding can legitimately keep >=N on ties; update the test to either relax the check for m==8 (use (kept >= n).all() when m == 8) or add a clear comment next to the assertion explaining that ties for M=8 may produce >=N and that test_sparsity_structure_ties covers strict tie behavior; reference test_sparsity_structure, the kept variable, and the call to _test_apply_sparse_nm so reviewers can find and apply the change.
1-1: Update copyright year to 2026.

New file should have the current copyright year.

📝 Suggested fix:

    -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py` at line 1, Update the SPDX copyright header line that currently reads "Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES." to the current year 2026; locate the SPDX header (the line beginning with "SPDX-FileCopyrightText") in tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py and replace 2024 with 2026 so the SPDX header reflects Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
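The tie behaviour flagged above can be seen with a threshold-based reference in plain Python (not the kernel's tl.sort path; the function name is illustrative):

```python
def kept_per_group(group, n):
    """Count survivors when elements >= the n-th largest value are kept:
    ties at the threshold can keep more than n elements per group."""
    threshold = sorted(group, reverse=True)[n - 1]
    return sum(1 for s in group if s >= threshold)

print(kept_per_group([5, 4, 3, 2, 2, 1, 0, -1], 2))  # distinct threshold -> exactly 2
print(kept_per_group([5, 4, 4, 4, 2, 1, 0, -1], 2))  # tie at the threshold -> 4 kept
```

Random float inputs make exact ties vanishingly unlikely, which is why the structure test usually sees exactly N; integer or low-precision inputs are where the ≥N case shows up.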
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py`:
- Line 1: Update the SPDX copyright header in the file that currently reads
"Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES" to use the year 2026;
locate the header at the top of the file (the SPDX/FileCopyrightText comment)
and change the year to 2026 so the license header reflects the file creation
year.
In `@tests/gpu/torch/sparsity/attention_sparsity/conftest.py`:
- Line 1: Update the SPDX license header year from 2024 to 2026 at the top of
the conftest.py file (the file-level comment starting with
"SPDX-FileCopyrightText") so the copyright line reflects the current year.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py`:
- Around line 232-243: test_sparsity_structure currently asserts exactly N kept
per group with (kept == n).all() but for M=8 the tl.sort-based thresholding can
legitimately keep >=N on ties; update the test to either relax the check for
m==8 (use (kept >= n).all() when m == 8) or add a clear comment next to the
assertion explaining that ties for M=8 may produce >=N and that
test_sparsity_structure_ties covers strict tie behavior; reference
test_sparsity_structure, the kept variable, and the call to
_test_apply_sparse_nm so reviewers can find and apply the change.
- Line 1: Update the SPDX copyright header line that currently reads "Copyright
(c) 2024 NVIDIA CORPORATION & AFFILIATES." to the current year 2026; locate the
SPDX header (the line beginning with "SPDX-FileCopyrightText") in
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py and
replace 2024 with 2026 so the SPDX header reflects Copyright (c) 2026 NVIDIA
CORPORATION & AFFILIATES.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: dac2310d-e99e-47c1-b62c-962ab622bdd4
📒 Files selected for processing (9)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
9882dbb to 67ae67b (Compare)
🧹 Nitpick comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)

84-94: Consider adding the `@abstractmethod` decorator for consistency.

`get_sparse_context` raises NotImplementedError but lacks the `@abstractmethod` decorator, unlike the `name` property (line 109). This inconsistency means subclasses won't get a clear error at instantiation time if they forget to implement it; they'll only fail at runtime when the method is called.

✨ Suggested change:

    +   @abstractmethod
        def get_sparse_context(self, module: torch.nn.Module):
            """Return a context manager that activates this method's sparsity during forward.

            Each method subclass implements its own activation mechanism:

            - Softmax-patching methods replace F.softmax during the forward pass.
            - Kernel-fused methods set flags on ``module`` that the kernel reads.

            Args:
                module: The SparseAttentionModule wrapping the attention layer.
            """
    -       raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()")

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py` around lines 84 - 94, get_sparse_context currently raises NotImplementedError but isn't decorated as abstract, causing subclasses to only error at runtime; mark get_sparse_context with the `@abstractmethod` decorator (same style used for the name property) so subclasses of the registry base class must implement get_sparse_context (which should return a context manager for activating sparsity on the SparseAttentionModule) and ensure imports/ABC usage are consistent with the existing abstract methods.

modelopt/torch/kernels/triton_fa.py (1)
866-871: Consider adding input validation for sparsity parameters.

The kernel enforces constraints via tl.static_assert (SPARSITY_M must be 4 or 8), but invalid combinations at the Python API level could produce unexpected behavior. For example, `sparsity_n >= sparsity_m` or `sparsity_n < 0` won't be caught until kernel compilation.

✨ Suggested validation:

        def attention(
            q: torch.Tensor,
            ...
            *,
            sparsity_n: int = 0,
            sparsity_m: int = 4,
            num_sink_tokens: int = 0,
            dense_window_size: int = 64,
        ) -> torch.Tensor:
    +       if sparsity_n < 0:
    +           raise ValueError(f"sparsity_n must be non-negative, got {sparsity_n}")
    +       if sparsity_n > 0:
    +           if sparsity_m not in (4, 8):
    +               raise ValueError(f"sparsity_m must be 4 or 8, got {sparsity_m}")
    +           if sparsity_n >= sparsity_m:
    +               raise ValueError(f"sparsity_n ({sparsity_n}) must be < sparsity_m ({sparsity_m})")
            ...

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 866 - 871, Add Python-level input validation at the start of the function that accepts sparsity_n and sparsity_m (the function with parameters sparsity_n: int = 0, sparsity_m: int = 4, num_sink_tokens, dense_window_size) to prevent invalid combos before kernel compilation: check that sparsity_m is one of the supported values (4 or 8), sparsity_n is an int >= 0 and strictly less than sparsity_m, and that both are integers; if any check fails, raise a ValueError with a clear message referencing sparsity_n/sparsity_m so users see the invalid values. Ensure these checks run before any tl.static_assert or kernel compilation logic so invalid inputs are caught early.
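A runnable standalone sketch of the suggested checks (illustrative: the real `attention()` takes tensors and many more parameters, so only the scalar validation is shown here):

```python
def validate_sparsity_args(sparsity_n, sparsity_m):
    """Sketch of the Python-level checks suggested above, raised before any
    kernel compilation would ever see the values."""
    if not isinstance(sparsity_n, int) or not isinstance(sparsity_m, int):
        raise TypeError("sparsity_n and sparsity_m must be integers")
    if sparsity_n < 0:
        raise ValueError(f"sparsity_n must be non-negative, got {sparsity_n}")
    if sparsity_n > 0:
        if sparsity_m not in (4, 8):
            raise ValueError(f"sparsity_m must be 4 or 8, got {sparsity_m}")
        if sparsity_n >= sparsity_m:
            raise ValueError(f"sparsity_n ({sparsity_n}) must be < sparsity_m ({sparsity_m})")

validate_sparsity_args(2, 4)  # ok: 2:4 sparsity
validate_sparsity_args(0, 4)  # ok: sparsity disabled, sparsity_m is ignored
```

Raising here gives the caller a message naming the offending values, instead of an opaque tl.static_assert failure at JIT time.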
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 866-871: Add Python-level input validation at the start of the
function that accepts sparsity_n and sparsity_m (the function with parameters
sparsity_n: int = 0, sparsity_m: int = 4, num_sink_tokens, dense_window_size) to
prevent invalid combos before kernel compilation: check that sparsity_m is one
of the supported values (4 or 8), sparsity_n is an int >= 0 and strictly less
than sparsity_m, and that both are integers; if any check fails, raise a
ValueError with a clear message referencing sparsity_n/sparsity_m so users see
the invalid values. Ensure these checks run before any tl.static_assert or
kernel compilation logic so invalid inputs are caught early.
In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py`:
- Around line 84-94: get_sparse_context currently raises NotImplementedError but
isn't decorated as abstract, causing subclasses to only error at runtime; mark
get_sparse_context with the `@abstractmethod` decorator (same style used for the
name property) so subclasses of the registry base class must implement
get_sparse_context (which should return a context manager for activating
sparsity on the SparseAttentionModule) and ensure imports/ABC usage are
consistent with the existing abstract methods.
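The instantiation-time failure described in the registry comment can be demonstrated with a minimal, self-contained sketch (class names here are hypothetical, not the actual registry classes):

```python
from abc import ABC, abstractmethod

class SparseMethodBase(ABC):
    """Minimal sketch of the suggested pattern: an abstract method turns a
    forgotten override into a TypeError at instantiation, rather than a
    NotImplementedError at call time."""

    @abstractmethod
    def get_sparse_context(self, module):
        """Return a context manager activating this method's sparsity."""

class Forgot(SparseMethodBase):
    pass  # does not implement get_sparse_context

try:
    Forgot()
except TypeError as exc:
    print(f"caught at instantiation: {exc}")
```

With only a `raise NotImplementedError` body and no decorator, `Forgot()` would construct fine and the error would surface much later, wherever the context is first requested.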
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8083fa36-d78b-4fde-9ad9-9ea3c8eb1666
📒 Files selected for processing (10)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (3)
- CHANGELOG.rst
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
Edwardf0t1
left a comment
LGTM overall, left some comments.
    # N:M sparse softmax — prefill only (decode should not sparsify KV)
    if not is_decode and getattr(module, "_apply_sparse_nm", False):
        method = getattr(module, "_sparse_method_instance", None)
Where _sparse_method_instance gets set, if it's outside this PR, please add a comment.
modelopt/torch/kernels/triton_fa.py (Outdated)

    is_sink = kv_start < NUM_SINK_TOKENS
    # causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q)
    causal_offset = seq_len_kv - seq_len_q
    q_abs_pos = tile_q * BLOCK_M + causal_offset
    token_distance = q_abs_pos - kv_start
    is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
    if not is_sink and not is_local:
        scores = _apply_sparse_nm_to_qk_tile(
            scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
        )
Consider extracting the duplicated sink/window check into a shared helper function.
Signed-off-by: Kai Xu <kaix@nvidia.com>
67ae67b to ee10e20 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (3)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
228-238: ⚠️ Potential issue | 🟠 Major: Validate `triton_sparse_softmax` against `backend="triton"` here.
This validator still accepts `{"method": "triton_sparse_softmax", "backend": "pytorch"}`. The example CLI can create exactly that config, and `TritonSparseSoftmaxMethod` relies on the fused Triton path rather than a PyTorch implementation, so this should fail during config validation instead of later.
🛠️ Suggested validation

```diff
 @model_validator(mode="after")
 def validate_sparsity_n_vs_m(self):
     """Validate sparsity_n is within the supported range for the given sparsity_m."""
+    if self.method == "triton_sparse_softmax" and self.backend != "triton":
+        raise ValueError("triton_sparse_softmax requires backend='triton'")
     if self.sparsity_n > 0:
         max_n = 3 if self.sparsity_m == 4 else self.sparsity_m - 1
         if self.sparsity_n > max_n:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `modelopt/torch/sparsity/attention_sparsity/config.py` around lines 228-238: the current validator `validate_sparsity_n_vs_m` only checks n vs m and misses validating that `method == "triton_sparse_softmax"` requires `backend == "triton"`; update that validator (or add a new `@model_validator(mode="after")` in the same config class) to check `self.method` (or the enum value for `TritonSparseSoftmaxMethod`) and raise a `ValueError` if `self.method == "triton_sparse_softmax" and self.backend != "triton"`, giving a clear message like "triton_sparse_softmax requires backend='triton'"; keep this validation in the config class so invalid combos are rejected during config validation.
modelopt/torch/kernels/triton_fa.py (2)
308-321: ⚠️ Potential issue | 🟠 Major: Keep the prefill-only contract inside the public Triton API too.
These branches only gate on `SPARSITY_N > 0`, so direct calls to `attention(..., b_start_loc_k=..., b_seq_len_k=..., sparsity_n>0)` will sparsify cached decode KV even though the feature is documented as prefill-only. The HF wrapper avoids passing these kwargs on decode, but the public kernel API now exposes them directly, so forward and both backward recompute paths should enforce the same guard.
🛠️ Suggested guard

```diff
-    if SPARSITY_N > 0:
+    is_decode = seq_len_q == 1 and seq_len_kv > 1
+    if SPARSITY_N > 0 and not is_decode:
```

Apply the same predicate in `_attn_fwd`, `_attn_bwd_dq`, and `_attn_bwd_dkdv`.
Also applies to: 509-523, 651-665
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 308 - 321, The N:M sparsity branches currently run whenever SPARSITY_N > 0, which lets external callers apply sparsity during decode; update the public Triton kernel entry points to enforce the prefill-only contract by combining the SPARSITY_N check with the same prefill predicate used by the HF wrapper (i.e., only apply sparsification when this is a prefill call — the decode-specific KV indices/lengths are not provided). Concretely, in _attn_fwd, _attn_bwd_dq, and _attn_bwd_dkdv wrap the existing SPARSITY_N > 0 blocks with an additional condition that b_start_loc_k and b_seq_len_kv indicate a prefill call (for example, only run the sparse branch when b_start_loc_k and b_seq_len_kv are unset/None or otherwise denote prefill), so cached decode KV paths are not sparsified.
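The suggested guard amounts to a small predicate; a plain-Python sketch of the intent (illustrative only — in the kernels this would be a branch on constexpr/runtime sequence lengths, not a separate function):

```python
def should_apply_sparsity(sparsity_n, seq_len_q, seq_len_kv):
    """Apply N:M sparsity only on prefill. A single query token attending
    to a longer cached KV sequence indicates a decode step, which must
    stay dense per the documented contract."""
    is_decode = seq_len_q == 1 and seq_len_kv > 1
    return sparsity_n > 0 and not is_decode
```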
161-185: ⚠️ Potential issue | 🔴 Critical: The "token" dense-region knobs are still quantized by launch tiles.
`_is_dense_region()` runs once per `(q_tile, kv_tile)` and only looks at `kv_start` and `tile_q * BLOCK_M`. So `num_sink_tokens=4` keeps the entire first KV tile dense (32/64/128 tokens depending on `BLOCK_N`), not the first 4 tokens, and the local-window decision changes when forward autotunes `BLOCK_M`/`BLOCK_N` while backward recomputes with fixed 64x64 tiles. For example, q[64:127] shares `tile_q=0` in forward when `BLOCK_M=128`, but both backward kernels recompute it as `tile_q=1`, so `dense_window_size=64` flips `kv_start=64` from sparse to dense. That makes backward differentiate a different mask than forward.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 161 - 185, _is_dense_region currently quantizes decisions to tile-aligned points (tile_q * BLOCK_M and kv_start), causing behavior to flip when BLOCK_M/BLOCK_N changes; fix it by treating KV and Q tiles as token ranges and checking range-overlap with the dense intervals instead of single-point comparisons: compute kv_tile_start = kv_start, kv_tile_end = min(kv_start + BLOCK_N, seq_len_kv), q_abs_start = tile_q * BLOCK_M + (seq_len_kv - seq_len_q), q_abs_end = min(q_abs_start + BLOCK_M, seq_len_q), then return true if the KV tile range overlaps the sink range [0, NUM_SINK_TOKENS) or the local window [q_abs_start - DENSE_WINDOW_SIZE + 1, q_abs_end) (use inclusive/exclusive semantics consistently); update _is_dense_region to use these range-overlap checks so decisions are token-accurate and invariant to BLOCK_M/BLOCK_N.
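A token-accurate version of that check under the reviewer's proposed range-overlap semantics, sketched in plain Python (names mirror the kernel's constexpr parameters; this is an illustrative sketch, not the PR's implementation):

```python
def is_dense_region_token_accurate(tile_q, kv_start, seq_len_q, seq_len_kv,
                                   BLOCK_M, BLOCK_N,
                                   NUM_SINK_TOKENS, DENSE_WINDOW_SIZE):
    """Decide per (q_tile, kv_tile) whether the tile overlaps a dense
    interval, using token ranges so the answer does not flip when
    BLOCK_M/BLOCK_N change between forward and backward launches."""
    kv_tile_start = kv_start
    kv_tile_end = min(kv_start + BLOCK_N, seq_len_kv)
    # Absolute query positions covered by this tile (chunked-prefill offset).
    causal_offset = seq_len_kv - seq_len_q
    q_abs_start = tile_q * BLOCK_M + causal_offset
    q_abs_end = min(q_abs_start + BLOCK_M, seq_len_kv)
    # Sink region: KV tokens [0, NUM_SINK_TOKENS).
    overlaps_sink = kv_tile_start < NUM_SINK_TOKENS
    # Local window: KV tokens within DENSE_WINDOW_SIZE of any query in the tile.
    win_start = q_abs_start - DENSE_WINDOW_SIZE + 1
    overlaps_window = kv_tile_end > win_start and kv_tile_start < q_abs_end
    return overlaps_sink or overlaps_window
```

Because the decision is still made once per tile, a tile that only partially overlaps a dense interval stays fully dense; that is conservative but consistent across tile sizes.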
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 31-35: The CLI exposes SPARSE_SOFTMAX_DEFAULT which requires
Triton but main() later overrides per-layer backends from args.backend (default
"pytorch"), so invoking hf_sa.py --sparse_attn sparse_softmax yields an invalid
config; after argument parsing in main() detect if args.sparse_attn (or the
per-layer choice derived from SPARSE_SOFTMAX_DEFAULT) equals "sparse_softmax"
and either (a) set args.backend = "triton" automatically or (b) raise a clear
error instructing the user to pass --backend triton; update the same check site
that currently rewrites layers (the code that reads args.backend and applies
per-layer rewrites) so it enforces or auto-selects Triton for sparse_softmax.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: dd3e47ce-a9c5-4cb5-a2ad-68b374dc46ea
📒 Files selected for processing (12)
- CHANGELOG.rst
- examples/llm_sparsity/attention_sparsity/README.md
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (3)
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- CHANGELOG.rst
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
Signed-off-by: Kai Xu <kaix@nvidia.com>
Force-pushed ee10e20 to 9fdeccd
@coderabbitai full review
✅ Actions performed: Full review triggered.
What does this PR do?
Type of change: New feature
Add N:M structured sparsity support to the Triton flash attention kernel (`modelopt/torch/kernels/triton_fa.py`). For every M consecutive key positions in the attention score tile, it keeps the top-N values and sets the rest to -inf before softmax. This is applied during prefill only.
Supported patterns: any N:M where M=4 (N=1,2,3) or M=8 (N=1..4).
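As a reference for the masking semantics, a minimal pure-Python sketch of N:M top-N selection over one row of attention scores (illustrative only; the PR implements this inside the fused Triton kernel, not in Python):

```python
import math

def nm_sparse_softmax_row(scores, n, m):
    """For every m consecutive scores, keep the top-n and mask the rest
    to -inf, then apply a numerically stable softmax over the row."""
    assert len(scores) % m == 0 and 0 < n < m
    masked = list(scores)
    for g in range(0, len(scores), m):
        group = scores[g:g + m]
        # Indices of the top-n scores within this group of m.
        top = sorted(range(m), key=lambda i: group[i], reverse=True)[:n]
        for i in range(m):
            if i not in top:
                masked[g + i] = -math.inf  # pruned before softmax
    mx = max(masked)
    exps = [math.exp(s - mx) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]
```

Masked positions contribute exactly zero probability, so each group of M keys carries at most N nonzero attention weights.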
Performance (TFLOPS at seq_len=16384, RTX 6000):
Usage
Testing
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
API
Configuration
Tests
Documentation