
[3/n] Add skip-softmax to Triton flash attention kernel#1081

Open
kaix-nv wants to merge 5 commits into main from kaix/triton_fa_skip_softmax

Conversation

@kaix-nv
Contributor

@kaix-nv kaix-nv commented Mar 20, 2026

What does this PR do?

Type of change: New feature

Adds skip-softmax tile skipping to the Triton flash attention kernel.

Usage

import torch
from modelopt.torch.kernels import attention

# Skip-softmax with threshold 0.1 (tiles contributing < 10% are skipped)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                skip_softmax_threshold=0.1)

# Via mtsa.sparsify() on HuggingFace models
import modelopt.torch.sparsity.attention_sparsity as mtsa
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B",
                                              torch_dtype=torch.bfloat16,
                                              device_map="cuda")

# Default config
mtsa.sparsify(model, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)

Testing

Performance (TFLOPS, RTX 6000 Pro):

SEQ_LEN  ModelOpt Triton  PyTorch SDPA  Flash Attention 2  Skip-Softmax t=0.01  Skip-Softmax t=0.1
16384    188.85           211.72        224.24             172.90               279.86
32768    175.32           212.82        224.83             146.15               262.49
65536    167.30           214.93        226.46             145.08               243.34

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a Triton "skip-softmax" tile-skipping option for flash attention with a new attention keyword and configurable threshold (default 0.1).
    • Added a new sparse attention method and a default sparse configuration that enables the Triton skip-softmax method.
  • Tests

    • Added GPU tests covering threshold behavior, numerical fidelity vs dense, shape preservation, decode-mode, and integration with sparsify.
  • Documentation

    • Updated changelog for the new feature and removed two prior listed entries.

@copy-pr-bot

copy-pr-bot bot commented Mar 20, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 20, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds a Triton-side skip-softmax tile-skipping optimization to flash attention, exposes a runtime skip_softmax_threshold kwarg in the attention API, integrates the method into sparsity config/registration and the HF wrapper, updates autograd/backward to reapply skips, and adds GPU tests and a changelog entry.
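The tile-skipping idea described in the walkthrough can be illustrated with a minimal single-row Python sketch (hypothetical names, scalar V values, and natural log rather than the kernel's log2-space `SKIP_THRESHOLD_LOG2` — not the actual Triton code): an online-softmax loop over score/value tiles that drops a tile when its maximum score is far enough below the running row max.

```python
import math

def online_softmax_with_skip(score_tiles, value_tiles, threshold=0.0):
    """Single-row online softmax over K/V tiles with tile skipping.

    A tile is skipped when exp(tile_max - running_max) < threshold, i.e.
    its best score contributes less than `threshold` relative to what has
    already been accumulated. threshold=0 disables skipping.
    """
    log_thr = math.log(threshold) if threshold > 0 else float("-inf")
    run_max = float("-inf")  # running row max (m_i)
    run_sum = 0.0            # running softmax denominator (l_i)
    acc = 0.0                # running output accumulator (scalar V)
    skipped = []
    for t, (scores, values) in enumerate(zip(score_tiles, value_tiles)):
        tile_max = max(scores)
        if tile_max - run_max < log_thr:
            skipped.append(t)  # skip softmax update and V accumulation
            continue
        new_max = max(run_max, tile_max)
        scale = math.exp(run_max - new_max)  # rescale earlier partials
        run_sum *= scale
        acc *= scale
        for s, v in zip(scores, values):
            p = math.exp(s - new_max)
            run_sum += p
            acc += p * v
        run_max = new_max
    return acc / run_sum, skipped

# Dense (threshold=0) vs. skipping: the low-scoring second tile is
# dropped with only a small change in the output.
tiles = [[2.0, 1.0], [-3.0, -4.0]]
vals = [[1.0, 2.0], [3.0, 4.0]]
dense, _ = online_softmax_with_skip(tiles, vals, 0.0)
approx, skipped = online_softmax_with_skip(tiles, vals, 0.5)
```

The real kernel operates on 2D query×key tiles and skips a tile only when every row in it satisfies the predicate, zeroing skipped rows otherwise; this sketch just shows the per-row arithmetic of the decision.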

Changes

Cohort / File(s) Summary
Triton FA kernel
modelopt/torch/kernels/triton_fa.py
Adds APPLY_SKIP_SOFTMAX and SKIP_THRESHOLD_LOG2 constexprs; implements per-row/tile max checks to optionally skip softmax/V work in forward; backward kernels reapply skip mask; _Attention and public attention(...) API accept and propagate skip_softmax_threshold.
HF attention wrapper
modelopt/torch/kernels/hf_triton_attention.py
triton_attention_forward conditionally injects skip_softmax_threshold into Triton kernel kwargs when module._apply_skip_softmax and the method instance threshold are present; no other metadata or post-processing changed.
Sparsity config & defaults
modelopt/torch/sparsity/attention_sparsity/config.py
Adds SparseAttentionAttributeConfig.skip_softmax_threshold: float = 0.1 and new preset SKIP_SOFTMAX_TRITON_DEFAULT targeting *attn* with method="triton_skip_softmax", backend="triton". Exports updated.
Sparsity methods pkg init
modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
Imports triton_skip_softmax at package init so the method registers on import (side-effect registration; no public API change).
New sparse method
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
Adds TritonSkipSoftmaxMethod (registered as triton_skip_softmax) that reads skip_softmax_threshold from config and provides a context manager toggling module._apply_skip_softmax during forward scope.
Methods registry
modelopt/torch/sparsity/attention_sparsity/methods/registry.py
Makes calculate_sparsity concrete (returns all-True mask and empty stats) and apply_sparsity concrete (raises NotImplementedError by default) instead of abstract methods.
Tests & changelog
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py, CHANGELOG.rst
Adds CUDA+Triton tests validating disabled/zero/low/high thresholds, decode-path behavior, monotonic error checks, and a HF integration test using SKIP_SOFTMAX_TRITON_DEFAULT; updates changelog to document the Triton skip-softmax tile-skipping entry and removes two prior entries from the 0.44 new-features list.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant HF as HF wrapper
    participant Sparsity as Sparsity config/method
    participant Triton as Triton FA kernel
    participant Autograd as Autograd ctx
    HF->>Sparsity: query sparse method / skip_softmax_threshold
    Sparsity-->>HF: return threshold (or None)
    HF->>Triton: call attention(..., skip_softmax_threshold=val)
    Triton->>Triton: compute tile_row_max and per-row can_skip
    alt all rows skippable for tile
        Triton-->>Triton: skip softmax update; skip V load/accumulation
    else some rows not skippable
        Triton-->>Triton: perform online softmax update; accumulate V (zero skipped rows)
    end
    Triton-->>Autograd: store skip flags on ctx
    Autograd->>Triton: backward launch; reapply skip mask to gradients
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main change: adding skip-softmax tile skipping functionality to the Triton flash attention kernel, which is the primary focus across all modified files.
Docstring Coverage ✅ Passed Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed PR introduces skip-softmax optimization for Triton flash attention kernel with no security anti-patterns detected. Code adheres to security best practices with no unsafe operations.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov

codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 47.82609% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.19%. Comparing base (291498b) to head (c49bca2).

Files with missing lines Patch % Lines
.../attention_sparsity/methods/triton_skip_softmax.py 42.10% 11 Missing ⚠️
...ch/sparsity/attention_sparsity/methods/registry.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1081      +/-   ##
==========================================
- Coverage   70.21%   70.19%   -0.02%     
==========================================
  Files         228      229       +1     
  Lines       25952    25976      +24     
==========================================
+ Hits        18221    18233      +12     
- Misses       7731     7743      +12     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kaix-nv kaix-nv changed the title from "Add skip-softmax tile skipping to Triton flash attention kernel" to "Add skip-softmax to Triton flash attention kernel" Mar 21, 2026
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 9225466 to cc0e9b3 on March 21, 2026 00:51
@kaix-nv kaix-nv marked this pull request as ready for review March 21, 2026 21:33
@kaix-nv kaix-nv requested a review from a team as a code owner March 21, 2026 21:33
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from cc0e9b3 to 270b94e on March 21, 2026 21:33
@kaix-nv kaix-nv requested a review from a team as a code owner March 21, 2026 21:33
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 270b94e to 6c65ef3 on March 21, 2026 21:35
@kaix-nv kaix-nv requested review from rohansjoshi and removed request for shengliangxu March 21, 2026 21:35
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 6c65ef3 to 012fb20 on March 21, 2026 21:37
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py (1)

494-505: Avoid asserting monotonic MAE from random samples.

Lines 500-505 assume that increasing the threshold must increase mean(abs(out_skip - out_dense)), but that is not guaranteed; extra skipped tiles can still reduce the final error through cancellation on a fixed input. This is likely to be flaky across seeds and GPU/dtype combinations. Prefer a directly monotonic signal, or weaken the expectation.
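The weakened endpoint check suggested above can be sketched as a small helper (hypothetical name and tolerance values, not code from this PR): compare only the smallest- and largest-threshold errors, with a tolerance, instead of asserting strict per-step growth.

```python
def check_error_trend(errors, rel_tol=0.1, abs_tol=1e-4):
    """Endpoint comparison with tolerance instead of strict per-step
    monotonicity: the error at the largest threshold should not be
    meaningfully smaller than the error at the smallest threshold.

    `errors` is assumed ordered by increasing skip_softmax_threshold.
    """
    assert errors, "need at least one measurement"
    tol = abs_tol + rel_tol * max(abs(errors[0]), abs(errors[-1]))
    return errors[0] <= errors[-1] + tol
```

This tolerates small regressions from cancellation on a fixed random input while still flagging a clearly inverted trend.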

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines
494 - 505, The test `test_monotonic_approximation_error` assumes mean absolute
error increases strictly with skip_softmax_threshold, which is flaky; change the
assertion to a weaker, robust check: compute errors for thresholds via
attention(q,k,v,locs,lens,512,softmax_scale=scale,skip_softmax_threshold=...),
then either remove the strict stepwise monotonic assertion and instead assert a
single inequality between the smallest and largest thresholds with a tolerance
(e.g., errors[0] <= errors[-1] + tol) or allow small per-step regressions by
checking non-decrease within a small relative/absolute tolerance; update the
final assert accordingly and keep references to the variables/functions used
(attention, out_dense, out_skip, errors, skip_softmax_threshold).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 380-384: The backward pass is recomputing the skip mask from final
lse which can differ from the forward per-tile running row_max; instead persist
the exact forward skip decisions (or the pre-tile row_max used in forward) so
the backward replays them exactly: modify the forward path that computes
tile_row_max / can_skip (used when APPLY_SKIP_SOFTMAX) to store the boolean skip
mask (or the pre-tile max) alongside tensors needed for backward and have the
backward use that saved mask when zeroing p (rather than recomputing can_skip
from lse and SKIP_THRESHOLD_LOG2); as a short-term alternative, gate
APPLY_SKIP_SOFTMAX to inference-only until you add this saved metadata so
gradients remain correct.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-107: The skip_softmax_threshold field must be validated to
ensure it is a fraction in [0, 1]; update the config parsing/validation so
negative values or values >1 raise during parse rather than silently changing
kernel behavior. Modify the typed config that defines skip_softmax_threshold
(the ModeloptField) to enforce 0.0 <= skip_softmax_threshold <= 1.0 — either by
adding a pydantic validator for skip_softmax_threshold or adding an explicit
check in the config class constructor/__post_init__ that raises a ValueError
with a clear message if the constraint is violated. Ensure the error triggers
during config parse/instantiation so callers get immediate feedback.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4a8c1dda-739d-4a3e-b939-e729f5e6858d

📥 Commits

Reviewing files that changed from the base of the PR and between 08e5f92 and 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

Comment on lines +99 to +107
    skip_softmax_threshold: float = ModeloptField(
        default=0.1,
        title="Skip-softmax threshold.",
        description=(
            "Tiles contributing less than this fraction are skipped entirely. "
            "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
            "Set to 0 to disable."
        ),
    )
Contributor


⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold in the typed config.

Line 99 introduces a public fraction, but negative or >1 values currently pass validation and change kernel behavior in surprising ways. Reject them at parse time instead of silently treating them as “disabled” or “skip almost everything.”

🧩 Suggested constraint
     skip_softmax_threshold: float = ModeloptField(
         default=0.1,
         title="Skip-softmax threshold.",
         description=(
             "Tiles contributing less than this fraction are skipped entirely. "
             "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
             "Set to 0 to disable."
         ),
+        ge=0.0,
+        le=1.0,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
The skip_softmax_threshold field must be validated to ensure it is a fraction in
[0, 1]; update the config parsing/validation so negative values or values >1
raise during parse rather than silently changing kernel behavior. Modify the
typed config that defines skip_softmax_threshold (the ModeloptField) to enforce
0.0 <= skip_softmax_threshold <= 1.0 — either by adding a pydantic validator for
skip_softmax_threshold or adding an explicit check in the config class
constructor/__post_init__ that raises a ValueError with a clear message if the
constraint is violated. Ensure the error triggers during config
parse/instantiation so callers get immediate feedback.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/kernels/triton_fa.py (1)

380-384: ⚠️ Potential issue | 🔴 Critical

Backward still cannot replay the forward skip decisions.

ctx only saves the scalar skip flag/threshold, so these backward kernels rebuild can_skip from final lse instead of the pre-tile row_max used in forward. Since lse is always at least as large as the forward running max, backward can zero gradients for tiles that were kept in forward. Please either persist the exact forward mask / pre-tile max or keep skip_softmax_threshold inference-only until backward can replay the same predicate. The public docstring should not claim “the same skip decision” until this is fixed.

🛡️ Safe short-term guard
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+        if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "skip_softmax_threshold is inference-only until backward can replay "
+                "the exact forward skip decisions."
+            )
         if apply_skip:
             import math

Also applies to: 510-514, 627-628
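The mismatch the reviewer describes is easy to reproduce with a toy calculation (plain-Python sketch with hypothetical function names, not kernel code): the forward predicate compares each tile's max score to the running max at the time the tile is processed, while an lse-based replay compares to the final log-sum-exp, which is never smaller than any intermediate running max.

```python
import math

def forward_skips(tile_maxes, log_thr):
    """Skip predicate as evaluated in forward: each tile's max is
    compared against the running max seen *before* that tile."""
    run_max, skips = float("-inf"), []
    for tm in tile_maxes:
        skip = tm - run_max < log_thr
        skips.append(skip)
        if not skip:
            run_max = max(run_max, tm)
    return skips

def replayed_skips(tile_maxes, lse, log_thr):
    """Skip predicate rebuilt from the final log-sum-exp, as the
    backward kernels currently do."""
    return [tm - lse < log_thr for tm in tile_maxes]

# Tile maxima arrive in increasing order: forward keeps every tile,
# but the lse-based replay flags the first two as skipped, so their
# gradient contributions would be zeroed incorrectly.
tile_maxes = [0.0, 5.0, 10.0]
lse = 10.0 + math.log(sum(math.exp(t - 10.0) for t in tile_maxes))
log_thr = math.log(0.1)
```

Persisting the forward mask (or the pre-tile running max) removes the divergence because backward then evaluates the identical predicate.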

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward
kernels are recomputing the skip predicate from lse, which differs from the
forward pre-tile max and causes incorrect gradient zeroing; change the forward
pass to save the exact per-tile skip mask or the pre-tile row_max into ctx (not
just the scalar skip_softmax_threshold) and have the backward kernels (the code
paths using APPLY_SKIP_SOFTMAX where can_skip is computed) read that saved
mask/value from ctx to reconstruct the exact same can_skip used in forward;
alternatively, make skip_softmax_threshold inference-only until backward can
replay the same predicate and update the public docstring to stop claiming “the
same skip decision” until fixed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 568-575: Validate the skip_softmax_threshold value before
computing skip_threshold_log2: treat only None or 0.0 as disabled, and raise a
ValueError for NaN, infinite, negative, or >1 values (accept only values in the
open interval (0, 1] for enabling). Update the logic around apply_skip,
skip_softmax_threshold, and skip_threshold_log2 to perform this check and raise
early with a clear message, and apply the same validation to the other
occurrence of the same pattern in this file (the block around the second
occurrence noted in the comment).


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fa0b5612-c9b2-47f7-a1bf-cb211e19a57e

📥 Commits

Reviewing files that changed from the base of the PR and between 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f and 012fb20.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py

Comment on lines +568 to +575
        # Skip-softmax: convert threshold to log2 space for the kernel
        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
        if apply_skip:
            import math

            skip_threshold_log2 = math.log2(skip_softmax_threshold)
        else:
            skip_threshold_log2 = 0.0
Contributor


⚠️ Potential issue | 🟡 Minor

Reject invalid skip_softmax_threshold values up front.

This knob is documented as a contribution fraction, but the host-side parsing currently accepts nan, inf, negatives, and values above 1. That means a typo can either silently disable the feature or make later tiles overly skippable. Please reserve None/0 as the only disable cases and raise on anything outside (0, 1].

🧪 Proposed fix
-        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-        if apply_skip:
-            import math
-
-            skip_threshold_log2 = math.log2(skip_softmax_threshold)
+        import math
+
+        if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+            apply_skip = False
+        else:
+            if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+                raise ValueError(
+                    "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+                )
+            apply_skip = True
+
+        if apply_skip:
+            skip_threshold_log2 = math.log2(skip_softmax_threshold)
         else:
             skip_threshold_log2 = 0.0

Also applies to: 762-768

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate the
skip_softmax_threshold value before computing skip_threshold_log2: treat only
None or 0.0 as disabled, and raise a ValueError for NaN, infinite, negative, or
>1 values (accept only values in the open interval (0, 1] for enabling). Update
the logic around apply_skip, skip_softmax_threshold, and skip_threshold_log2 to
perform this check and raise early with a clear message, and apply the same
validation to the other occurrence of the same pattern in this file (the block
around the second occurrence noted in the comment).

@kaix-nv kaix-nv changed the title from "Add skip-softmax to Triton flash attention kernel" to "[3/n] Add skip-softmax to Triton flash attention kernel" Mar 23, 2026
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 012fb20 to ecc5540 on March 23, 2026 22:35
@github-actions
Contributor

github-actions bot commented Mar 23, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1081/

Built to branch gh-pages at 2026-03-26 02:14 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (3)
modelopt/torch/kernels/triton_fa.py (2)

380-384: ⚠️ Potential issue | 🔴 Critical

Do not enable skip_softmax_threshold during training yet.

Forward skips against the pre-tile running row_max, but these backward paths rebuild can_skip from final lse. That can zero gradients for tiles that were not skipped in forward, so training with this flag is still incorrect. Please either persist the exact forward skip mask / pre-tile max or gate this mode to inference only.

🛡️ Safe short-term guard
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+        if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "skip_softmax_threshold is inference-only until backward can replay "
+                "the exact forward skip decisions."
+            )
         if apply_skip:
             import math

Also applies to: 510-514

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward
code recomputes can_skip from lse and SKIP_THRESHOLD_LOG2 which can differ from
the forward decision (tile_row_max) causing incorrect zeroed gradients when
APPLY_SKIP_SOFTMAX (skip_softmax_threshold) is enabled; fix by either (A)
persisting the exact forward skip mask (compute and store tile_row_max and/or
can_skip from the forward pass and reuse that mask in the backward path when
restoring p) or (B) disallowing this mode during training by gating
APPLY_SKIP_SOFTMAX/skip_softmax_threshold to inference-only; update logic that
references tile_row_max, can_skip, scores, lse and p to use the persisted mask
or the inference-only guard accordingly.

568-575: ⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold before computing log2.

Only None and 0 are documented disable cases. Negative, non-finite, or >1 values currently either get silently treated as off or make skipping much more aggressive than the API contract suggests.

🧪 Suggested input validation
-        # Skip-softmax: convert threshold to log2 space for the kernel
-        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-        if apply_skip:
-            import math
-
-            skip_threshold_log2 = math.log2(skip_softmax_threshold)
-        else:
-            skip_threshold_log2 = 0.0
+        # Skip-softmax: convert threshold to log2 space for the kernel
+        import math
+
+        if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+            apply_skip = False
+            skip_threshold_log2 = 0.0
+        else:
+            if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+                raise ValueError(
+                    "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+                )
+            apply_skip = True
+            skip_threshold_log2 = math.log2(skip_softmax_threshold)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate
skip_softmax_threshold before computing log2: when computing apply_skip and
skip_threshold_log2 (use the variables skip_softmax_threshold, apply_skip,
skip_threshold_log2 and the math.log2 call), ensure that if
skip_softmax_threshold is not None it is a finite numeric value and within the
documented range (0 < value <= 1); treat 0 or None as “off”; for values that are
negative, non-finite (NaN/inf) or >1 raise a clear ValueError (or TypeError for
wrong type) with a message explaining allowed values so the code never silently
treats invalid inputs as off or miscomputes the log2.
modelopt/torch/sparsity/attention_sparsity/config.py (1)

99-107: ⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold during config parsing.

This new public fraction still accepts negatives, non-finite values, and values above 1, which makes the Triton path either silently disable skipping or skip far too aggressively. Reject invalid values when the config is instantiated instead of relying on runtime behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
The public field skip_softmax_threshold can be negative, non-finite, or >1; add
validation at config instantiation so invalid values are rejected early: in the
config class that defines skip_softmax_threshold (the class using
ModeloptField), implement a validation step (e.g. __post_init__ or a
pydantic/ModeloptField validator) that checks the value is finite and 0.0 <=
skip_softmax_threshold <= 1.0 and raise a ValueError with a clear message if
not; this ensures invalid inputs are caught when the config is created rather
than at runtime in triton_skip_softmax.
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)

65-71: Restore the previous module flag in finally.

This context manager always writes False on exit, so nested or stacked uses on the same module can clobber an outer active context. Restoring the prior value makes the activation state composable.

♻️ Suggested fix
         @contextmanager
         def _skip_softmax_context():
+            prev_apply_skip_softmax = getattr(module, "_apply_skip_softmax", False)
             module._apply_skip_softmax = True
             try:
                 yield
             finally:
-                module._apply_skip_softmax = False
+                module._apply_skip_softmax = prev_apply_skip_softmax
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 65 - 71, The _skip_softmax_context context manager currently
overwrites module._apply_skip_softmax to False on exit, which breaks nested
contexts; modify _skip_softmax_context to save the prior value (e.g., prev =
module._apply_skip_softmax), set module._apply_skip_softmax = True on entry, and
in the finally block restore module._apply_skip_softmax = prev so nested or
stacked uses of the context preserve outer states (apply this change inside the
_skip_softmax_context definition).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b3835fd4-3e45-4467-a16f-8477c8ba3c2c

📥 Commits

Reviewing files that changed from the base of the PR and between 012fb20 and ecc5540.

📒 Files selected for processing (7)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (2)
  • CHANGELOG.rst
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

@kaix-nv kaix-nv requested review from Edwardf0t1 and jingyu-ml March 24, 2026 00:16
m_new = tl.maximum(row_max, tile_row_max)
p = tl.math.exp2(scores - m_new[:, None])
# Zero out skipped rows (instead of masking scores and recomputing max)
p = tl.where(can_skip[:, None], 0.0, p)
@jingyu-ml jingyu-ml Mar 24, 2026
This is confusing: if not all_skip, we don't skip this tile, so why add p = tl.where(can_skip[:, None], 0.0, p) here? We skip nothing for this tile, and we are doing tile-level skipping.

I guess all_skip means that, for this specific tile, it's true if we skip it and false otherwise. If that's the case, the variable name is confusing as well.

@jingyu-ml jingyu-ml Mar 24, 2026
It seems to me this line should be deleted.

@Edwardf0t1 Edwardf0t1 left a comment

LGTM overall, left some comments.

# Re-apply skip-softmax: zero out rows that were skipped in forward
if APPLY_SKIP_SOFTMAX:
tile_row_max = tl.max(scores, 1)
can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
So forward and backward differ: forward skips based on the running row_max, backward on the final lse.

Since lse >= row_max always holds, the backward condition tile_row_max < lse + SKIP_THRESHOLD_LOG2 is strictly looser, so it can skip tiles that the forward actually kept. This means the backward may zero gradients for work that was done in the forward.

Is this intentional?
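The mismatch between the two conditions can be checked numerically. This is a hedged sketch with illustrative log2-scaled scores (not taken from the kernel), showing a tile that the forward's row_max test keeps but the backward's lse test would skip:

```python
import math

threshold_log2 = math.log2(0.1)   # lambda = 0.1  ->  about -3.32

# One tile of log2-scaled scores for a single query row (illustrative values).
tile_scores = [2.0, 1.5, 1.0]
tile_row_max = max(tile_scores)   # 2.0

# Forward compares the tile max against the *running* row max.
row_max = 4.0                     # running max from earlier tiles
fwd_skip = tile_row_max < row_max + threshold_log2   # 2.0 < 0.68 -> kept

# Backward compares against the final lse, which is always >= row_max.
earlier_scores = [4.0, 3.9, 3.8]
lse = math.log2(sum(2.0 ** s for s in earlier_scores + tile_scores))
assert lse >= row_max
bwd_skip = tile_row_max < lse + threshold_log2       # 2.0 < ~2.43 -> skipped

# The tile was computed in the forward but would be zeroed in the backward.
assert fwd_skip is False and bwd_skip is True
```

Because lse accumulates every tile's mass, `lse + threshold_log2` can exceed `row_max + threshold_log2`, flipping the decision for tiles near the cutoff.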

Comment on lines +167 to +214
if APPLY_SKIP_SOFTMAX:
# --- Skip-softmax path: check tile, skip V load if all rows negligible ---
# Compute tile row max once — reused for both skip check and softmax update
tile_row_max = tl.max(scores, 1) # [BLOCK_M]
can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
all_skip = tl.min(can_skip.to(tl.int32)) == 1

if not all_skip:
# Online softmax update (reuses tile_row_max — no second tl.max)
# For skipped rows: tile_row_max < row_max, so m_new = row_max (no change)
m_new = tl.maximum(row_max, tile_row_max)
p = tl.math.exp2(scores - m_new[:, None])
# Zero out skipped rows (instead of masking scores and recomputing max)
p = tl.where(can_skip[:, None], 0.0, p)
l_new = tl.sum(p, 1)
correction = tl.math.exp2(row_max - m_new)
row_sum = row_sum * correction + l_new
acc = acc * correction[:, None]

# Load V and accumulate
v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
v = tl.load(
v_base + v_offs,
mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
other=0.0,
)
acc = tl.dot(p.to(v.dtype), v, acc)
row_max = m_new
# else: all rows negligible — skip V load, softmax update, accumulation
else:
# --- Standard path: no skip check ---
# Online softmax update
m_new = tl.maximum(row_max, tl.max(scores, 1))
p = tl.math.exp2(scores - m_new[:, None])
l_new = tl.sum(p, 1)
correction = tl.math.exp2(row_max - m_new)
row_sum = row_sum * correction + l_new
acc = acc * correction[:, None]

# Load V and accumulate
v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
v = tl.load(
v_base + v_offs,
mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
other=0.0,
)
acc = tl.dot(p.to(v.dtype), v, acc)
row_max = m_new
There is one path for APPLY_SKIP_SOFTMAX=True and one for False. The standard path and the "not all_skip" branch inside the skip path are nearly identical. Since APPLY_SKIP_SOFTMAX is a tl.constexpr, Triton will compile-time eliminate the dead branch, but the source duplication makes maintenance harder.

Consider structuring as:

# Skip check (compiled out when APPLY_SKIP_SOFTMAX=False)
do_process = True
if APPLY_SKIP_SOFTMAX:
    tile_row_max = tl.max(scores, 1)
    can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
    all_skip = tl.min(can_skip.to(tl.int32)) == 1
    if all_skip:
        do_process = False

if do_process:
    # Single copy of softmax update + V accumulation
    ...
    if APPLY_SKIP_SOFTMAX:
        p = tl.where(can_skip[:, None], 0.0, p)
    ...
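As a reference for the single-copy structure, here is a plain NumPy model of the online-softmax loop with tile-level skipping. It is a sketch of the algorithm under the same variable names (tile_row_max, can_skip, row_sum, acc), not the Triton kernel itself:

```python
import numpy as np

def online_softmax_av(scores_tiles, v_tiles, threshold_log2=None):
    """Online softmax A @ V over KV tiles, in base-2 log space like the kernel.

    scores_tiles: list of [M, N] log2-scaled score tiles
    v_tiles:      list of [N, D] value tiles
    threshold_log2: if set, a tile is skipped when every row's tile max
                    falls below row_max + threshold_log2.
    """
    M = scores_tiles[0].shape[0]
    D = v_tiles[0].shape[1]
    row_max = np.full(M, -np.inf)
    row_sum = np.zeros(M)
    acc = np.zeros((M, D))
    for scores, v in zip(scores_tiles, v_tiles):
        tile_row_max = scores.max(axis=1)
        if threshold_log2 is not None:
            can_skip = tile_row_max < row_max + threshold_log2
            if can_skip.all():
                continue  # whole tile negligible: skip V load and update
        else:
            can_skip = np.zeros(M, dtype=bool)
        # Single copy of the softmax update + V accumulation.
        m_new = np.maximum(row_max, tile_row_max)
        p = np.exp2(scores - m_new[:, None])
        p[can_skip] = 0.0  # zero rows that fell below the threshold
        correction = np.exp2(row_max - m_new)
        row_sum = row_sum * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v
        row_max = m_new
    return acc / row_sum[:, None]

# Dense check: with no threshold the loop reproduces exact softmax @ V.
rng = np.random.default_rng(0)
s = rng.normal(size=(4, 8))
v = rng.normal(size=(8, 3))
out = online_softmax_av([s[:, :4], s[:, 4:]], [v[:4], v[4:]])
p = np.exp2(s - s.max(axis=1, keepdims=True))
assert np.allclose(out, (p / p.sum(axis=1, keepdims=True)) @ v)
```

With a very small threshold (e.g. `np.log2(1e-6)`) no tile in this random example crosses the cutoff, so the output matches the dense path, which mirrors the kernel's "skip nothing, lose nothing" behavior at tiny thresholds.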

@kaix-nv

kaix-nv commented Mar 25, 2026

@coderabbitai full review

@coderabbitai

coderabbitai bot commented Mar 25, 2026

✅ Actions performed

Full review triggered.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
modelopt/torch/kernels/triton_fa.py (2)

568-575: ⚠️ Potential issue | 🟠 Major

Reject invalid skip_softmax_threshold values up front.

Direct callers can bypass the config path entirely, and this currently accepts negatives, non-finite values, and values above 1. In particular, inf will make every post-first tile trivially skippable and can drive the kernel into NaNs. Please reserve None/0 as the disable cases and raise on anything outside finite (0, 1].

🧪 Suggested fix
-        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-        if apply_skip:
-            import math
-
-            skip_threshold_log2 = math.log2(skip_softmax_threshold)
-        else:
-            skip_threshold_log2 = 0.0
+        import math
+
+        if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+            apply_skip = False
+            skip_threshold_log2 = 0.0
+        else:
+            if not math.isfinite(skip_softmax_threshold) or not (
+                0.0 < skip_softmax_threshold <= 1.0
+            ):
+                raise ValueError(
+                    "skip_softmax_threshold must be a finite float in (0, 1], "
+                    "or None/0 to disable."
+                )
+            apply_skip = True
+            skip_threshold_log2 = math.log2(skip_softmax_threshold)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate
skip_softmax_threshold up front: treat only None or 0.0 as disable, otherwise
require a finite number in the range (0, 1]; use math.isfinite and raise
ValueError for negatives, zero (except explicit 0.0 disable), non-finite
(inf/NaN), or >1. Then compute skip_threshold_log2 as
math.log2(skip_softmax_threshold) and set apply_skip based on that validation
(use the existing symbols skip_softmax_threshold, apply_skip,
skip_threshold_log2).

380-384: ⚠️ Potential issue | 🔴 Critical

Backward is not replaying the forward skip decisions.

Forward decides can_skip from the per-tile running row_max, but backward recomputes it from final lse. Since lse is not the same state and is always at least as large as the pre-tile max, these branches can zero gradients for tiles that were actually kept in forward. That makes skip_softmax_threshold gradient-incorrect whenever any input requires grad. Please either save the exact forward skip mask / pre-tile max for backward, or gate this mode to inference-only until that metadata exists.

🛡️ Safe short-term guard
         if apply_skip:
             import math

             skip_threshold_log2 = math.log2(skip_softmax_threshold)
         else:
             skip_threshold_log2 = 0.0
+
+        if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "skip_softmax_threshold is inference-only until backward can replay "
+                "the exact forward skip decisions."
+            )

Also applies to: 510-514

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward
pass is recomputing the skip decision from `lse` (using `SKIP_THRESHOLD_LOG2`)
which differs from the forward per-tile `tile_row_max`, causing incorrect
gradients; fix by persisting the exact forward skip decision (save `can_skip` or
`tile_row_max` from the forward path) and use that saved mask in the backward to
zero out `p` (instead of recomputing from `lse`), or, until that metadata is
stored, disable `APPLY_SKIP_SOFTMAX` for any inputs that require gradients (gate
the optimization to inference-only); update the forward to stash the mask (or
pre-tile max) and update the backward to read and apply that saved mask when
processing `p`.
modelopt/torch/sparsity/attention_sparsity/config.py (1)

99-107: ⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold during config parse.

This is a public fraction field, but negatives and values above 1 still validate here and only fail later as confusing kernel behavior. Please reject anything outside [0.0, 1.0] at parse time.

🧪 Suggested fix
     skip_softmax_threshold: float = ModeloptField(
         default=0.1,
         title="Skip-softmax threshold.",
         description=(
             "Tiles contributing less than this fraction are skipped entirely. "
             "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
             "Set to 0 to disable."
         ),
+        ge=0.0,
+        le=1.0,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
Add validation for skip_softmax_threshold in the config parsing so values
outside [0.0, 1.0] are rejected early: in the code that constructs/parses the
config (the ModeloptField definition for skip_softmax_threshold or the config
class's validation/__post_init__/validate method), check the
skip_softmax_threshold value and raise a clear ValueError (or use the config
validation mechanism) if skip_softmax_threshold < 0.0 or skip_softmax_threshold
> 1.0, ensuring any negative or >1 inputs fail at parse time rather than later
in the Triton kernel.
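A minimal sketch of parse-time rejection, using a stdlib dataclass __post_init__ in place of the real ModeloptField/pydantic machinery; the class name here is hypothetical:

```python
from dataclasses import dataclass
import math

@dataclass
class SkipSoftmaxConfig:
    # Fraction of contribution below which a tile is skipped; 0 disables.
    skip_softmax_threshold: float = 0.1

    def __post_init__(self):
        t = self.skip_softmax_threshold
        # Reject non-numeric, non-finite, and out-of-range values at
        # construction time instead of deep inside the Triton kernel.
        if not isinstance(t, (int, float)) or not math.isfinite(t) or not (0.0 <= t <= 1.0):
            raise ValueError(
                "skip_softmax_threshold must be a finite float in [0.0, 1.0]; "
                f"got {t!r} (0 disables skipping)"
            )

SkipSoftmaxConfig(0.1)          # ok
try:
    SkipSoftmaxConfig(1.5)      # rejected when the config is created
except ValueError as e:
    print(e)
```

With pydantic-backed fields the same effect is typically achieved declaratively via `ge=0.0, le=1.0` constraints, as the suggested fix above does.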
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py (1)

562-600: ⚠️ Potential issue | 🟡 Minor

This integration test still doesn’t prove skip-softmax was enabled.

The case is intentionally short and then asserts logits_skip ~= logits_dense, so it still passes if triton_attention_forward() never forwards skip_softmax_threshold or if the threshold never causes any tile to be skipped on this model. Please force a multi-tile prompt / larger tiny-model fixture, or assert that the kwarg reaches attention() directly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines
562 - 600, The test currently never verifies that skip-softmax was actually
enabled; update test_skip_softmax_via_sparsify to either (A) force multi-tile
behavior by using a longer input (increase ids length beyond a single
tile/sequence chunk) or set SKIP_SOFTMAX_TRITON_DEFAULT to a value that will
trigger skipping, and then check outputs, or (B) directly assert the kwarg is
forwarded by monkeypatching/wrapping the attention implementation (wrap
triton_attention_forward or the model's attention() method instances obtained
from the loaded AutoModelForCausalLM) to capture its kwargs and assert
skip_softmax_threshold (or a boolean like skip_softmax) is present and set;
reference functions/classes: test_skip_softmax_via_sparsify, mtsa.sparsify,
mtsa.SKIP_SOFTMAX_TRITON_DEFAULT, triton_attention_forward, attention().
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`:
- Around line 494-505: Remove the unstable total ordering assertion in
test_monotonic_approximation_error that requires errors[0] <= errors[1] <=
errors[2]; instead keep deterministic sanity checks: ensure each computed error
is finite and non-negative (e.g., errors[i] is not NaN and errors[i] >= 0) and
replace the strict chain with a single small-vs-large check between the smallest
and largest threshold (e.g., errors[0] <= errors[2]) to validate that very small
thresholds produce no larger error than very large thresholds; update references
to attention, skip_softmax_threshold, out_dense, out_skip, and errors
accordingly.

---

Duplicate comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 568-575: Validate skip_softmax_threshold up front: treat only None
or 0.0 as disable, otherwise require a finite number in the range (0, 1]; use
math.isfinite and raise ValueError for negatives, zero (except explicit 0.0
disable), non-finite (inf/NaN), or >1. Then compute skip_threshold_log2 as
math.log2(skip_softmax_threshold) and set apply_skip based on that validation
(use the existing symbols skip_softmax_threshold, apply_skip,
skip_threshold_log2).
- Around line 380-384: The backward pass is recomputing the skip decision from
`lse` (using `SKIP_THRESHOLD_LOG2`) which differs from the forward per-tile
`tile_row_max`, causing incorrect gradients; fix by persisting the exact forward
skip decision (save `can_skip` or `tile_row_max` from the forward path) and use
that saved mask in the backward to zero out `p` (instead of recomputing from
`lse`), or, until that metadata is stored, disable `APPLY_SKIP_SOFTMAX` for any
inputs that require gradients (gate the optimization to inference-only); update
the forward to stash the mask (or pre-tile max) and update the backward to read
and apply that saved mask when processing `p`.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-107: Add validation for skip_softmax_threshold in the config
parsing so values outside [0.0, 1.0] are rejected early: in the code that
constructs/parses the config (the ModeloptField definition for
skip_softmax_threshold or the config class's validation/__post_init__/validate
method), check the skip_softmax_threshold value and raise a clear ValueError (or
use the config validation mechanism) if skip_softmax_threshold < 0.0 or
skip_softmax_threshold > 1.0, ensuring any negative or >1 inputs fail at parse
time rather than later in the Triton kernel.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`:
- Around line 562-600: The test currently never verifies that skip-softmax was
actually enabled; update test_skip_softmax_via_sparsify to either (A) force
multi-tile behavior by using a longer input (increase ids length beyond a single
tile/sequence chunk) or set SKIP_SOFTMAX_TRITON_DEFAULT to a value that will
trigger skipping, and then check outputs, or (B) directly assert the kwarg is
forwarded by monkeypatching/wrapping the attention implementation (wrap
triton_attention_forward or the model's attention() method instances obtained
from the loaded AutoModelForCausalLM) to capture its kwargs and assert
skip_softmax_threshold (or a boolean like skip_softmax) is present and set;
reference functions/classes: test_skip_softmax_via_sparsify, mtsa.sparsify,
mtsa.SKIP_SOFTMAX_TRITON_DEFAULT, triton_attention_forward, attention().

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: cfbbdb8e-2c07-4cbd-8692-fb6b141d1661

📥 Commits

Reviewing files that changed from the base of the PR and between c425524 and ecc5540.

📒 Files selected for processing (7)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

Comment on lines +494 to +505
def test_monotonic_approximation_error(self):
"""Larger threshold -> larger error vs dense (monotonic degradation)."""
q, k, v, locs, lens = self._make_inputs(seq_len=512)
scale = 1.0 / (64**0.5)
out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale)
errors = []
for threshold in [1e-4, 1e-2, 1e-1]:
out_skip = attention(
q, k, v, locs, lens, 512, softmax_scale=scale, skip_softmax_threshold=threshold
)
errors.append((out_skip - out_dense).abs().mean().item())
assert errors[0] <= errors[1] <= errors[2], f"Errors not monotonic: {errors}"
⚠️ Potential issue | 🟡 Minor

errors[0] <= errors[1] <= errors[2] is not a stable correctness invariant.

Higher thresholds skip more work, but the output error vs. dense can still decrease because the dropped V contributions can cancel. That makes this assertion a potential flake on a correct kernel. I’d keep the existing small/large-threshold sanity checks and drop this total ordering unless you can assert a more direct monotonic signal.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines
494 - 505, Remove the unstable total ordering assertion in
test_monotonic_approximation_error that requires errors[0] <= errors[1] <=
errors[2]; instead keep deterministic sanity checks: ensure each computed error
is finite and non-negative (e.g., errors[i] is not NaN and errors[i] >= 0) and
replace the strict chain with a single small-vs-large check between the smallest
and largest threshold (e.g., errors[0] <= errors[2]) to validate that very small
thresholds produce no larger error than very large thresholds; update references
to attention, skip_softmax_threshold, out_dense, out_skip, and errors
accordingly.
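The cancellation argument is easy to see in a toy sketch: dropping more small contributions of opposite sign can reduce the error versus the dense result. The numbers below are illustrative only:

```python
# Toy model: approximate a sum by dropping small contributions, the way
# skip-softmax drops low-weight tiles. Dense keeps everything.
contributions = [1.0, 0.07, -0.09]
dense = sum(contributions)               # 0.98

def approx(threshold):
    # Keep only contributions at least `threshold` in magnitude.
    return sum(c for c in contributions if abs(c) >= threshold)

err_mid = abs(approx(0.08) - dense)      # drops +0.07 only      -> error 0.07
err_big = abs(approx(0.10) - dense)      # drops both small terms -> error 0.02
assert err_big < err_mid                 # more skipping, smaller error
```

A higher threshold removed both small terms, which partially cancel, so the aggressive setting lands closer to the dense answer. This is why strict monotonicity of `errors` is not a stable invariant.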

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
modelopt/torch/kernels/triton_fa.py (1)

169-190: ⚠️ Potential issue | 🔴 Critical

Backward is still replaying a different skip rule than forward.

Forward only skips when skip_tile is true, using the per-tile running row_max. Backward rebuilds can_skip from final lse and then zeros rows individually, so mixed tiles and later tiles can drop gradients for work that was actually done in forward. Please save the exact forward tile-skip decision (or the pre-tile row_max) and replay that, or keep this mode inference-only.

🛡️ Safe short-term guard
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+        if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "skip_softmax_threshold is inference-only until backward can replay "
+                "the exact forward tile-skip decisions."
+            )
         if apply_skip:
             skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale

Also applies to: 393-397, 523-527

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 169 - 190, Forward computes
tile_row_max, can_skip and skip_tile using row_max and SKIP_THRESHOLD_LOG2 but
the backward recomputes can_skip from lse causing mismatch; to fix, capture and
store the exact forward decision (e.g., a boolean per-tile mask like
skip_tile_mask or the pre-tile row_max values computed by tile_row_max) inside
the forward path (near tile_row_max / can_skip / skip_tile) and have the
backward path replay that saved mask/value instead of recomputing from lse so
gradients are only zeroed for tiles actually skipped in forward (apply same
change to the other occurrences referenced around lines 393-397 and 523-527).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 581-590: The skip-softmax threshold is being over-scaled: when
computing skip_threshold_log2 inside the function handling
skip_softmax_threshold, do not multiply math.log2(skip_softmax_threshold) by
sm_scale because scores are already in qk_scale (raw_score * sm_scale * LOG2E)
space; set skip_threshold_log2 = math.log2(skip_softmax_threshold) when
apply_skip is true (leave the else branch as 0.0), and keep references to
skip_softmax_threshold, skip_threshold_log2, sm_scale, qk_scale, and scores to
locate and update the calculation.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 51-57: The _skip_softmax_context context manager currently
unconditionally sets module._apply_skip_softmax = False on exit, which breaks
nested contexts; modify _skip_softmax_context to save the previous value (e.g.,
old = getattr(module, "_apply_skip_softmax", False)) before setting it to True,
and in the finally block restore module._apply_skip_softmax = old so
nested/re-entrant uses of _skip_softmax_context correctly preserve outer state.

---

Duplicate comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 169-190: Forward computes tile_row_max, can_skip and skip_tile
using row_max and SKIP_THRESHOLD_LOG2 but the backward recomputes can_skip from
lse causing mismatch; to fix, capture and store the exact forward decision
(e.g., a boolean per-tile mask like skip_tile_mask or the pre-tile row_max
values computed by tile_row_max) inside the forward path (near tile_row_max /
can_skip / skip_tile) and have the backward path replay that saved mask/value
instead of recomputing from lse so gradients are only zeroed for tiles actually
skipped in forward (apply same change to the other occurrences referenced around
lines 393-397 and 523-527).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6ca1822c-2d5d-4782-ba9d-52393ce9a916

📥 Commits

Reviewing files that changed from the base of the PR and between ecc5540 and 7c966b1.

📒 Files selected for processing (4)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/kernels/hf_triton_attention.py

Comment on lines +581 to +590
# Skip-softmax: convert threshold to scaled log2 space for the kernel.
# The BLASST reference (https://arxiv.org/pdf/2512.12087) checks
# ln(lambda) on unscaled scores. Our kernel works in log2-scaled space
# (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we
# pre-scale: threshold_scaled = log2(lambda) * sm_scale.
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
if apply_skip:
skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
else:
skip_threshold_log2 = 0.0
⚠️ Potential issue | 🔴 Critical

Don't multiply the skip cutoff by sm_scale again.

scores already live in raw_score * sm_scale * log2(e) space via qk_scale, so the contribution cutoff in that same space is just log2(lambda). Multiplying by sm_scale here makes the effective threshold head-dim dependent and much looser than the documented fraction.

🐛 Proposed fix
-        # pre-scale: threshold_scaled = log2(lambda) * sm_scale.
+        # In kernel space, score deltas already include `sm_scale * log2(e)`,
+        # so the contribution cutoff is simply `log2(lambda)`.
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
         if apply_skip:
-            skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
+            skip_threshold_log2 = math.log2(skip_softmax_threshold)
         else:
             skip_threshold_log2 = 0.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 581 - 590, The skip-softmax
threshold is being over-scaled: when computing skip_threshold_log2 inside the
function handling skip_softmax_threshold, do not multiply
math.log2(skip_softmax_threshold) by sm_scale because scores are already in
qk_scale (raw_score * sm_scale * LOG2E) space; set skip_threshold_log2 =
math.log2(skip_softmax_threshold) when apply_skip is true (leave the else branch
as 0.0), and keep references to skip_softmax_threshold, skip_threshold_log2,
sm_scale, qk_scale, and scores to locate and update the calculation.
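The over-scaling effect can be checked with plain arithmetic; the gap value below is illustrative, assuming head_dim = 64 so sm_scale = 0.125:

```python
import math

LOG2E = math.log2(math.e)
sm_scale = 1.0 / math.sqrt(64)       # 0.125 for head_dim = 64
lam = 0.1                            # documented skip fraction

# Raw (unscaled) gap between a tile's score max and the running row max.
raw_gap = -11.0
# Kernel scores carry qk_scale = sm_scale * log2(e), so the kernel-space gap is:
gap = raw_gap * sm_scale * LOG2E     # about -1.98

# Upper bound on the tile's relative contribution in kernel space:
contribution = 2.0 ** gap            # about 0.25, well above lam = 0.1

keep_correct = not (gap < math.log2(lam))             # correct rule keeps it
skip_overscaled = gap < math.log2(lam) * sm_scale     # over-scaled rule skips it

assert contribution > lam
assert keep_correct and skip_overscaled
```

The over-scaled cutoff `log2(0.1) * 0.125` is about -0.42 instead of -3.32, so tiles contributing up to roughly 75% of the running max get skipped, far looser than the documented 10% fraction.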

Comment on lines +51 to +57
@contextmanager
def _skip_softmax_context():
module._apply_skip_softmax = True
try:
yield
finally:
module._apply_skip_softmax = False
⚠️ Potential issue | 🟡 Minor

Restore the previous _apply_skip_softmax state on exit.

The finally block always writes False, so a nested/re-entrant context on the same module will clear an outer skip-softmax scope too early. Save the old value and restore it instead.

💡 Proposed fix
        @contextmanager
         def _skip_softmax_context():
-            module._apply_skip_softmax = True
+            prev_apply_skip_softmax = getattr(module, "_apply_skip_softmax", False)
+            module._apply_skip_softmax = True
             try:
                 yield
             finally:
-                module._apply_skip_softmax = False
+                module._apply_skip_softmax = prev_apply_skip_softmax
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 51 - 57, the _skip_softmax_context context manager currently
unconditionally sets module._apply_skip_softmax = False on exit, which breaks
nested contexts; modify _skip_softmax_context to save the previous value (e.g.,
old = getattr(module, "_apply_skip_softmax", False)) before setting it to True,
and in the finally block restore module._apply_skip_softmax = old so
nested/re-entrant uses of _skip_softmax_context correctly preserve outer state.
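The nesting hazard can be reproduced outside the kernel with a minimal stand-in for the module (a sketch; `SimpleNamespace` replaces the real attention module, and the function names are illustrative):

```python
from contextlib import contextmanager
from types import SimpleNamespace

module = SimpleNamespace(_apply_skip_softmax=False)

@contextmanager
def skip_softmax_context():
    # Save-and-restore variant from the proposed fix.
    prev = getattr(module, "_apply_skip_softmax", False)
    module._apply_skip_softmax = True
    try:
        yield
    finally:
        module._apply_skip_softmax = prev

with skip_softmax_context():
    with skip_softmax_context():
        pass
    # With the unconditional `= False` version, the flag would already be
    # cleared at this point; save/restore keeps the outer scope active.
    assert module._apply_skip_softmax is True
assert module._apply_skip_softmax is False
```

The inner `with` exiting must not clobber the outer scope's state, which is exactly what restoring `prev` guarantees.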

@kaix-nv
Contributor Author

kaix-nv commented Mar 26, 2026

@coderabbitai full review

@coderabbitai
Contributor

coderabbitai bot commented Mar 26, 2026

✅ Actions performed

Full review triggered.

kaix-nv added 4 commits March 25, 2026 17:58
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 7c966b1 to 59849a1 Compare March 26, 2026 00:58

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (3)
modelopt/torch/sparsity/attention_sparsity/config.py (1)

99-107: ⚠️ Potential issue | 🟡 Minor

Add range validation for skip_softmax_threshold.

The field accepts any float, but values outside [0, 1] are invalid (negative values or >1 would cause unexpected kernel behavior). Add Pydantic constraints to reject invalid values at parse time.

🧩 Suggested constraint
     skip_softmax_threshold: float = ModeloptField(
         default=0.1,
         title="Skip-softmax threshold.",
         description=(
             "Tiles contributing less than this fraction are skipped entirely. "
             "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
             "Set to 0 to disable."
         ),
+        ge=0.0,
+        le=1.0,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
Add validation to skip_softmax_threshold so values are constrained to [0,1]:
update the ModeloptField declaration for skip_softmax_threshold in config.py
(the skip_softmax_threshold field) to include Pydantic range constraints (e.g.,
pass ge=0 and le=1 to ModeloptField or change the type to a
pydantic.confloat(ge=0, le=1)) so parsing rejects negative or >1 values.
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)

48-59: ⚠️ Potential issue | 🟡 Minor

Restore previous _apply_skip_softmax state on exit.

The finally block unconditionally sets False, breaking nested/re-entrant contexts. Save and restore the old value instead.

💡 Proposed fix
        @contextmanager
         def _skip_softmax_context():
+            prev = getattr(module, "_apply_skip_softmax", False)
             module._apply_skip_softmax = True
             try:
                 yield
             finally:
-                module._apply_skip_softmax = False
+                module._apply_skip_softmax = prev
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 48 - 59, The context manager returned by get_sparse_context sets
module._apply_skip_softmax to True but unconditionally resets it to False on
exit, breaking nested/re-entrant usage; modify the inner _skip_softmax_context
in get_sparse_context to save the previous value (e.g., prev =
module._apply_skip_softmax) before setting True, yield, and then restore
module._apply_skip_softmax = prev in the finally block so the original state is
preserved for nested contexts.
modelopt/torch/kernels/triton_fa.py (1)

586-590: ⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold before computing log2.

Invalid values (negative, NaN, inf, or >1) will cause incorrect behavior. Reject them early with a clear error.

🛡️ Proposed fix
+        import math
+
+        if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+            apply_skip = False
+        elif not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+            raise ValueError(
+                f"skip_softmax_threshold must be in (0, 1], or None/0 to disable, got {skip_softmax_threshold}"
+            )
+        else:
+            apply_skip = True
-        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
         if apply_skip:
             skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
         else:
             skip_threshold_log2 = 0.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 586 - 590, Validate
skip_softmax_threshold before using math.log2: when skip_softmax_threshold is
not None, check with math.isfinite and not math.isnan and ensure 0 <
skip_softmax_threshold <= 1 (treat 0 or None as "no skip"); if the value fails
these checks raise a ValueError with a clear message referencing
skip_softmax_threshold. Then keep the existing logic that sets apply_skip =
skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 and compute
skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale only when
apply_skip; otherwise set skip_threshold_log2 = 0.0. Use the variable names
skip_softmax_threshold, apply_skip, skip_threshold_log2, and sm_scale to locate
and update the code.
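Collapsed into a plain helper, the validation described in the prompts above might look like this (a sketch only; the real kernel inlines the logic, and `resolve_skip_threshold` is a hypothetical name):

```python
import math

def resolve_skip_threshold(skip_softmax_threshold):
    """Return (apply_skip, skip_threshold_log2) after validating the input.

    None or 0.0 disables skipping; anything outside (0, 1] is rejected.
    Note: per the earlier review comment, sm_scale is NOT folded in here,
    since kernel-space scores already carry the sm_scale * log2(e) factor.
    """
    if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
        return False, 0.0
    if not math.isfinite(skip_softmax_threshold) or not (
        0.0 < skip_softmax_threshold <= 1.0
    ):
        raise ValueError(
            "skip_softmax_threshold must be in (0, 1], or None/0 to disable, "
            f"got {skip_softmax_threshold}"
        )
    return True, math.log2(skip_softmax_threshold)
```

`math.isfinite` rejects both NaN and infinities in one check (it returns False for NaN), so a separate `isnan` test is redundant.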
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 586-590: Validate skip_softmax_threshold before using math.log2:
when skip_softmax_threshold is not None, check with math.isfinite and not
math.isnan and ensure 0 < skip_softmax_threshold <= 1 (treat 0 or None as "no
skip"); if the value fails these checks raise a ValueError with a clear message
referencing skip_softmax_threshold. Then keep the existing logic that sets
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
and compute skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
only when apply_skip; otherwise set skip_threshold_log2 = 0.0. Use the variable
names skip_softmax_threshold, apply_skip, skip_threshold_log2, and sm_scale to
locate and update the code.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-107: Add validation to skip_softmax_threshold so values are
constrained to [0,1]: update the ModeloptField declaration for
skip_softmax_threshold in config.py (the skip_softmax_threshold field) to
include Pydantic range constraints (e.g., pass ge=0 and le=1 to ModeloptField or
change the type to a pydantic.confloat(ge=0, le=1)) so parsing rejects negative
or >1 values.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 48-59: The context manager returned by get_sparse_context sets
module._apply_skip_softmax to True but unconditionally resets it to False on
exit, breaking nested/re-entrant usage; modify the inner _skip_softmax_context
in get_sparse_context to save the previous value (e.g., prev =
module._apply_skip_softmax) before setting True, yield, and then restore
module._apply_skip_softmax = prev in the finally block so the original state is
preserved for nested contexts.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 74df2426-ad1e-4293-983c-ab3ba946d16f

📥 Commits

Reviewing files that changed from the base of the PR and between c425524 and 7c966b1.

📒 Files selected for processing (8)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)

37-48: Consider adding type hints for consistency.

The codebase uses type hints extensively. Adding annotations would improve static type checking with mypy.

💡 Suggested type hints
-    def __init__(self, method_config=None):
+    def __init__(self, method_config: dict | None = None):
         """Initialize with skip-softmax threshold from config."""
         super().__init__()
         method_config = method_config or {}
         self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1)

    @property
     def name(self) -> str:
         """Method name identifier."""
         return "triton_skip_softmax"

-    def get_sparse_context(self, module):
+    def get_sparse_context(self, module: "torch.nn.Module"):
         """Return context manager that activates skip-softmax during forward."""

Note: You'll need to add from __future__ import annotations or import torch for the type hint.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 37 - 48, Add static type annotations to the constructor and methods
in this file: annotate the __init__ parameter method_config (e.g.,
Optional[Dict[str, Any]]) and the instance attribute skip_softmax_threshold as
float, annotate the name property return type as str (already hinted but keep
consistent), and annotate get_sparse_context(module) with an appropriate type
for module (e.g., torch.nn.Module) and a return type (e.g., Any or a specific
SparseContext type). Also add the necessary imports (from __future__ import
annotations and typing imports like Optional, Dict, Any, plus import torch) at
the top so mypy/static checkers can validate the signatures for __init__, name,
and get_sparse_context.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 37-48: Add static type annotations to the constructor and methods
in this file: annotate the __init__ parameter method_config (e.g.,
Optional[Dict[str, Any]]) and the instance attribute skip_softmax_threshold as
float, annotate the name property return type as str (already hinted but keep
consistent), and annotate get_sparse_context(module) with an appropriate type
for module (e.g., torch.nn.Module) and a return type (e.g., Any or a specific
SparseContext type). Also add the necessary imports (from __future__ import
annotations and typing imports like Optional, Dict, Any, plus import torch) at
the top so mypy/static checkers can validate the signatures for __init__, name,
and get_sparse_context.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b76f064d-ad1e-49e0-85b4-5a8ab65b40d9

📥 Commits

Reviewing files that changed from the base of the PR and between 7c966b1 and 59849a1.

📒 Files selected for processing (8)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (1)
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/kernels/triton_fa.py

Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 9a03035 to c49bca2 Compare March 26, 2026 02:10