[3/n] Add skip-softmax to Triton flash attention kernel #1081
📝 Walkthrough: Adds a Triton-side skip-softmax tile-skipping optimization to flash attention and exposes a runtime `skip_softmax_threshold` configuration knob.
Codecov Report: ❌ patch coverage is incomplete. Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1081      +/-   ##
==========================================
- Coverage   70.21%   70.19%   -0.02%
==========================================
  Files         228      229       +1
  Lines       25952    25976      +24
==========================================
+ Hits        18221    18233      +12
- Misses       7731     7743      +12
```

☔ View full report in Codecov by Sentry.
Force-pushed 9225466 to cc0e9b3
Force-pushed cc0e9b3 to 270b94e
Force-pushed 270b94e to 6c65ef3
Force-pushed 6c65ef3 to 012fb20
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py (1)
494-505: Avoid asserting monotonic MAE from random samples.

Lines 500-505 assume that increasing the threshold must increase `mean(abs(out_skip - out_dense))`, but that is not guaranteed; extra skipped tiles can still reduce the final error through cancellation on a fixed input. This is likely to be flaky across seeds and GPU/dtype combinations. Prefer a directly monotonic signal, or weaken the expectation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines 494 - 505, The test `test_monotonic_approximation_error` assumes mean absolute error increases strictly with skip_softmax_threshold, which is flaky; change the assertion to a weaker, robust check: compute errors for thresholds via attention(q,k,v,locs,lens,512,softmax_scale=scale,skip_softmax_threshold=...), then either remove the strict stepwise monotonic assertion and instead assert a single inequality between the smallest and largest thresholds with a tolerance (e.g., errors[0] <= errors[-1] + tol) or allow small per-step regressions by checking non-decrease within a small relative/absolute tolerance; update the final assert accordingly and keep references to the variables/functions used (attention, out_dense, out_skip, errors, skip_softmax_threshold).
📥 Commits
Reviewing files that changed from the base of the PR and between 08e5f92 and 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f.
📒 Files selected for processing (6)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
```python
skip_softmax_threshold: float = ModeloptField(
    default=0.1,
    title="Skip-softmax threshold.",
    description=(
        "Tiles contributing less than this fraction are skipped entirely. "
        "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
        "Set to 0 to disable."
    ),
)
```
Validate skip_softmax_threshold in the typed config.
Line 99 introduces a public fraction, but negative or >1 values currently pass validation and change kernel behavior in surprising ways. Reject them at parse time instead of silently treating them as “disabled” or “skip almost everything.”
🧩 Suggested constraint
```diff
 skip_softmax_threshold: float = ModeloptField(
     default=0.1,
     title="Skip-softmax threshold.",
     description=(
         "Tiles contributing less than this fraction are skipped entirely. "
         "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
         "Set to 0 to disable."
     ),
+    ge=0.0,
+    le=1.0,
 )
```
🤖 Prompt for AI Agents
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
The skip_softmax_threshold field must be validated to ensure it is a fraction in
[0, 1]; update the config parsing/validation so negative values or values >1
raise during parse rather than silently changing kernel behavior. Modify the
typed config that defines skip_softmax_threshold (the ModeloptField) to enforce
0.0 <= skip_softmax_threshold <= 1.0 — either by adding a pydantic validator for
skip_softmax_threshold or adding an explicit check in the config class
constructor/__post_init__ that raises a ValueError with a clear message if the
constraint is violated. Ensure the error triggers during config
parse/instantiation so callers get immediate feedback.
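A quick sketch of what the suggested `ge`/`le` constraint buys, using plain pydantic (this assumes `ModeloptField` forwards these kwargs to pydantic's `Field`, as the diff above implies):

```python
from pydantic import BaseModel, Field, ValidationError

class SkipSoftmaxConfig(BaseModel):
    # Stand-in for the ModeloptField declaration with ge/le added.
    skip_softmax_threshold: float = Field(default=0.1, ge=0.0, le=1.0)

SkipSoftmaxConfig(skip_softmax_threshold=0.05)  # parses fine
try:
    SkipSoftmaxConfig(skip_softmax_threshold=1.5)  # rejected at parse time
except ValidationError as exc:
    print(exc)  # "Input should be less than or equal to 1"
```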
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/kernels/triton_fa.py (1)
380-384: ⚠️ Potential issue | 🔴 Critical — Backward still cannot replay the forward skip decisions.

`ctx` only saves the scalar skip flag/threshold, so these backward kernels rebuild `can_skip` from the final `lse` instead of the pre-tile `row_max` used in forward. Since `lse` is always at least as large as the forward running max, backward can zero gradients for tiles that were kept in forward. Please either persist the exact forward mask / pre-tile max or keep `skip_softmax_threshold` inference-only until backward can replay the same predicate. The public docstring should not claim "the same skip decision" until this is fixed.

🛡️ Safe short-term guard
```diff
 apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+    raise NotImplementedError(
+        "skip_softmax_threshold is inference-only until backward can replay "
+        "the exact forward skip decisions."
+    )
 if apply_skip:
     import math
```
Also applies to: 510-514, 627-628

🤖 Prompt for AI Agents
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward kernels are recomputing the skip predicate from lse, which differs from the forward pre-tile max and causes incorrect gradient zeroing; change the forward pass to save the exact per-tile skip mask or the pre-tile row_max into ctx (not just the scalar skip_softmax_threshold) and have the backward kernels (the code paths using APPLY_SKIP_SOFTMAX where can_skip is computed) read that saved mask/value from ctx to reconstruct the exact same can_skip used in forward; alternatively, make skip_softmax_threshold inference-only until backward can replay the same predicate and update the public docstring to stop claiming “the same skip decision” until fixed.
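For the "persist the exact forward decisions" option, a toy autograd sketch (illustrative plumbing only, single tile, not the actual Triton kernels) shows the intent: the forward stashes its skip mask in `ctx`, and the backward replays that mask instead of re-deriving it from `lse`:

```python
import math

import torch

class SkipSoftmaxExp2(torch.autograd.Function):
    """Toy exp2-softmax numerator with a persisted skip mask."""

    @staticmethod
    def forward(ctx, scores, threshold_log2):
        row_max = scores.max(dim=-1, keepdim=True).values
        # Forward skip decision, made against the (single-tile) running max.
        skip_mask = scores < (row_max + threshold_log2)
        p = torch.exp2(scores - row_max).masked_fill(skip_mask, 0.0)
        ctx.save_for_backward(p, skip_mask)  # the mask travels to backward
        return p

    @staticmethod
    def backward(ctx, grad_p):
        p, skip_mask = ctx.saved_tensors
        # d(exp2(x))/dx = exp2(x) * ln(2); row_max treated as detached here.
        grad_scores = grad_p * p * math.log(2.0)
        # Replay the saved mask: zero exactly where forward skipped, no more.
        return grad_scores.masked_fill(skip_mask, 0.0), None
```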
📥 Commits
Reviewing files that changed from the base of the PR and between 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f and 012fb20.
📒 Files selected for processing (6)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (1)
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
modelopt/torch/kernels/triton_fa.py
Outdated
```python
# Skip-softmax: convert threshold to log2 space for the kernel
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
if apply_skip:
    import math

    skip_threshold_log2 = math.log2(skip_softmax_threshold)
else:
    skip_threshold_log2 = 0.0
```
Reject invalid skip_softmax_threshold values up front.
This knob is documented as a contribution fraction, but the host-side parsing currently accepts nan, inf, negatives, and values above 1. That means a typo can either silently disable the feature or make later tiles overly skippable. Please reserve None/0 as the only disable cases and raise on anything outside (0, 1].
🧪 Proposed fix
```diff
-    apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-    if apply_skip:
-        import math
-
-        skip_threshold_log2 = math.log2(skip_softmax_threshold)
+    import math
+
+    if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+        apply_skip = False
+    else:
+        if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+            raise ValueError(
+                "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+            )
+        apply_skip = True
+
+    if apply_skip:
+        skip_threshold_log2 = math.log2(skip_softmax_threshold)
     else:
         skip_threshold_log2 = 0.0
```
Also applies to: 762-768

🤖 Prompt for AI Agents
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate the
skip_softmax_threshold value before computing skip_threshold_log2: treat only
None or 0.0 as disabled, and raise a ValueError for NaN, infinite, negative, or
>1 values (accept only values in the open interval (0, 1] for enabling). Update
the logic around apply_skip, skip_softmax_threshold, and skip_threshold_log2 to
perform this check and raise early with a clear message, and apply the same
validation to the other occurrence of the same pattern in this file (the block
around the second occurrence noted in the comment).
Force-pushed 012fb20 to ecc5540
♻️ Duplicate comments (3)
modelopt/torch/kernels/triton_fa.py (2)
380-384: ⚠️ Potential issue | 🔴 Critical — Do not enable `skip_softmax_threshold` during training yet.

Forward skips against the pre-tile running `row_max`, but these backward paths rebuild `can_skip` from the final `lse`. That can zero gradients for tiles that were not skipped in forward, so training with this flag is still incorrect. Please either persist the exact forward skip mask / pre-tile max or gate this mode to inference only.

🛡️ Safe short-term guard
```diff
 apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+    raise NotImplementedError(
+        "skip_softmax_threshold is inference-only until backward can replay "
+        "the exact forward skip decisions."
+    )
 if apply_skip:
     import math
```
Also applies to: 510-514
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward code recomputes can_skip from lse and SKIP_THRESHOLD_LOG2 which can differ from the forward decision (tile_row_max) causing incorrect zeroed gradients when APPLY_SKIP_SOFTMAX (skip_softmax_threshold) is enabled; fix by either (A) persisting the exact forward skip mask (compute and store tile_row_max and/or can_skip from the forward pass and reuse that mask in the backward path when restoring p) or (B) disallowing this mode during training by gating APPLY_SKIP_SOFTMAX/skip_softmax_threshold to inference-only; update logic that references tile_row_max, can_skip, scores, lse and p to use the persisted mask or the inference-only guard accordingly.
568-575: ⚠️ Potential issue | 🟡 Minor — Validate `skip_softmax_threshold` before computing `log2`.

Only `None` and `0` are documented disable cases. Negative, non-finite, or `>1` values currently either get silently treated as off or make skipping much more aggressive than the API contract suggests.

🧪 Suggested input validation
```diff
-    # Skip-softmax: convert threshold to log2 space for the kernel
-    apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-    if apply_skip:
-        import math
-
-        skip_threshold_log2 = math.log2(skip_softmax_threshold)
-    else:
-        skip_threshold_log2 = 0.0
+    # Skip-softmax: convert threshold to log2 space for the kernel
+    import math
+
+    if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+        apply_skip = False
+        skip_threshold_log2 = 0.0
+    else:
+        if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+            raise ValueError(
+                "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+            )
+        apply_skip = True
+        skip_threshold_log2 = math.log2(skip_softmax_threshold)
```
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate skip_softmax_threshold before computing log2: when computing apply_skip and skip_threshold_log2 (use the variables skip_softmax_threshold, apply_skip, skip_threshold_log2 and the math.log2 call), ensure that if skip_softmax_threshold is not None it is a finite numeric value and within the documented range (0 < value <= 1); treat 0 or None as "off"; for values that are negative, non-finite (NaN/inf) or >1 raise a clear ValueError (or TypeError for wrong type) with a message explaining allowed values so the code never silently treats invalid inputs as off or miscomputes the log2.

modelopt/torch/sparsity/attention_sparsity/config.py (1)
99-107: ⚠️ Potential issue | 🟡 Minor — Validate `skip_softmax_threshold` during config parsing.

This new public fraction still accepts negatives, non-finite values, and values above `1`, which makes the Triton path either silently disable skipping or skip far too aggressively. Reject invalid values when the config is instantiated instead of relying on runtime behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107, The public field skip_softmax_threshold can be negative, non-finite, or >1; add validation at config instantiation so invalid values are rejected early: in the config class that defines skip_softmax_threshold (the class using ModeloptField), implement a validation step (e.g. __post_init__ or a pydantic/ModeloptField validator) that checks the value is finite and 0.0 <= skip_softmax_threshold <= 1.0 and raise a ValueError with a clear message if not; this ensures invalid inputs are caught when the config is created rather than at runtime in triton_skip_softmax.
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)
65-71: Restore the previous module flag in `finally`.

This context manager always writes `False` on exit, so nested or stacked uses on the same module can clobber an outer active context. Restoring the prior value makes the activation state composable.

♻️ Suggested fix
```diff
 @contextmanager
 def _skip_softmax_context():
+    prev_apply_skip_softmax = getattr(module, "_apply_skip_softmax", False)
     module._apply_skip_softmax = True
     try:
         yield
     finally:
-        module._apply_skip_softmax = False
+        module._apply_skip_softmax = prev_apply_skip_softmax
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py` around lines 65 - 71, The _skip_softmax_context context manager currently overwrites module._apply_skip_softmax to False on exit, which breaks nested contexts; modify _skip_softmax_context to save the prior value (e.g., prev = module._apply_skip_softmax), set module._apply_skip_softmax = True on entry, and in the finally block restore module._apply_skip_softmax = prev so nested or stacked uses of the context preserve outer states (apply this change inside the _skip_softmax_context definition).
📒 Files selected for processing (7)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (2)
- CHANGELOG.rst
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
modelopt/torch/kernels/triton_fa.py
Outdated
```python
m_new = tl.maximum(row_max, tile_row_max)
p = tl.math.exp2(scores - m_new[:, None])
# Zero out skipped rows (instead of masking scores and recomputing max)
p = tl.where(can_skip[:, None], 0.0, p)
```
This is confusing: if not `all_skip`, we don't skip this tile, so why add `p = tl.where(can_skip[:, None], 0.0, p)` here? We skip nothing for this tile, and the skipping is tile-level.
I guess `all_skip` means that, for this specific tile, it's true if we skip it and false otherwise. If that's the case, the variable name is confusing as well.
Seems to me this line should be deleted.
Edwardf0t1 left a comment:
LGTM overall, left some comments.
```python
# Re-apply skip-softmax: zero out rows that were skipped in forward
if APPLY_SKIP_SOFTMAX:
    tile_row_max = tl.max(scores, 1)
    can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
```
So forward and backward are different: forward skips based on `row_max`, backward on `lse`.
Since `lse >= row_max` always holds, the backward predicate is strictly looser: it can mark more tiles as skippable than the forward did. This means the backward may zero gradients for tiles that were actually computed in the forward.
Is this intentional?
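The inequality behind this concern can be checked numerically (log2-space notation, one tile for simplicity):

```python
# lse = log2(sum_j 2^{s_j}) >= max_j s_j always holds, so a predicate of the
# form `tile_row_max < lse + SKIP_THRESHOLD_LOG2` is satisfied at least as
# often as `tile_row_max < row_max + SKIP_THRESHOLD_LOG2`.
import math

import torch

scores = torch.randn(4, 64)                        # one row-tile of kernel-space scores
row_max = scores.max(dim=-1).values                # running max (single tile here)
ln2 = math.log(2.0)
lse = torch.logsumexp(scores * ln2, dim=-1) / ln2  # log2-space logsumexp
assert torch.all(lse >= row_max)                   # equality only for one-hot rows
```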
```python
if APPLY_SKIP_SOFTMAX:
    # --- Skip-softmax path: check tile, skip V load if all rows negligible ---
    # Compute tile row max once — reused for both skip check and softmax update
    tile_row_max = tl.max(scores, 1)  # [BLOCK_M]
    can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
    all_skip = tl.min(can_skip.to(tl.int32)) == 1

    if not all_skip:
        # Online softmax update (reuses tile_row_max — no second tl.max)
        # For skipped rows: tile_row_max < row_max, so m_new = row_max (no change)
        m_new = tl.maximum(row_max, tile_row_max)
        p = tl.math.exp2(scores - m_new[:, None])
        # Zero out skipped rows (instead of masking scores and recomputing max)
        p = tl.where(can_skip[:, None], 0.0, p)
        l_new = tl.sum(p, 1)
        correction = tl.math.exp2(row_max - m_new)
        row_sum = row_sum * correction + l_new
        acc = acc * correction[:, None]

        # Load V and accumulate
        v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
        v = tl.load(
            v_base + v_offs,
            mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
            other=0.0,
        )
        acc = tl.dot(p.to(v.dtype), v, acc)
        row_max = m_new
    # else: all rows negligible — skip V load, softmax update, accumulation
else:
    # --- Standard path: no skip check ---
    # Online softmax update
    m_new = tl.maximum(row_max, tl.max(scores, 1))
    p = tl.math.exp2(scores - m_new[:, None])
    l_new = tl.sum(p, 1)
    correction = tl.math.exp2(row_max - m_new)
    row_sum = row_sum * correction + l_new
    acc = acc * correction[:, None]

    # Load V and accumulate
    v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
    v = tl.load(
        v_base + v_offs,
        mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
        other=0.0,
    )
    acc = tl.dot(p.to(v.dtype), v, acc)
    row_max = m_new
```
It seems one path for APPLY_SKIP_SOFTMAX=True and one for False. The standard path and the "not all_skip" branch inside the skip path are nearly identical. Since APPLY_SKIP_SOFTMAX is a tl.constexpr, Triton will compile-time eliminate the dead branch, but the source code duplication makes maintenance harder.
Consider structuring as:
```python
# Skip check (compiled out when APPLY_SKIP_SOFTMAX=False)
do_process = True
if APPLY_SKIP_SOFTMAX:
    tile_row_max = tl.max(scores, 1)
    can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
    all_skip = tl.min(can_skip.to(tl.int32)) == 1
    if all_skip:
        do_process = False

if do_process:
    # Single copy of softmax update + V accumulation
    ...
    if APPLY_SKIP_SOFTMAX:
        p = tl.where(can_skip[:, None], 0.0, p)
    ...
```
@coderabbitai full review
✅ Actions performed: full review triggered.
Actionable comments posted: 1
♻️ Duplicate comments (4)
modelopt/torch/kernels/triton_fa.py (2)
568-575: ⚠️ Potential issue | 🟠 Major — Reject invalid `skip_softmax_threshold` values up front.

Direct callers can bypass the config path entirely, and this currently accepts negatives, non-finite values, and values above `1`. In particular, `inf` will make every post-first tile trivially skippable and can drive the kernel into NaNs. Please reserve `None`/`0` as the disable cases and raise on anything outside finite `(0, 1]`.

🧪 Suggested fix
```diff
-    apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-    if apply_skip:
-        import math
-
-        skip_threshold_log2 = math.log2(skip_softmax_threshold)
-    else:
-        skip_threshold_log2 = 0.0
+    import math
+
+    if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+        apply_skip = False
+        skip_threshold_log2 = 0.0
+    else:
+        if not math.isfinite(skip_softmax_threshold) or not (
+            0.0 < skip_softmax_threshold <= 1.0
+        ):
+            raise ValueError(
+                "skip_softmax_threshold must be a finite float in (0, 1], "
+                "or None/0 to disable."
+            )
+        apply_skip = True
+        skip_threshold_log2 = math.log2(skip_softmax_threshold)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate skip_softmax_threshold up front: treat only None or 0.0 as disable, otherwise require a finite number in the range (0, 1]; use math.isfinite and raise ValueError for negatives, zero (except explicit 0.0 disable), non-finite (inf/NaN), or >1. Then compute skip_threshold_log2 as math.log2(skip_softmax_threshold) and set apply_skip based on that validation (use the existing symbols skip_softmax_threshold, apply_skip, skip_threshold_log2).
380-384: ⚠️ Potential issue | 🔴 Critical — Backward is not replaying the forward skip decisions.

Forward decides `can_skip` from the per-tile running `row_max`, but backward recomputes it from the final `lse`. Since `lse` is not the same state and is always at least as large as the pre-tile max, these branches can zero gradients for tiles that were actually kept in forward. That makes `skip_softmax_threshold` gradient-incorrect whenever any input requires grad. Please either save the exact forward skip mask / pre-tile max for backward, or gate this mode to inference-only until that metadata exists.

🛡️ Safe short-term guard
```diff
 if apply_skip:
     import math

     skip_threshold_log2 = math.log2(skip_softmax_threshold)
 else:
     skip_threshold_log2 = 0.0
+
+if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+    raise NotImplementedError(
+        "skip_softmax_threshold is inference-only until backward can replay "
+        "the exact forward skip decisions."
+    )
```
Also applies to: 510-514
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward pass is recomputing the skip decision from `lse` (using `SKIP_THRESHOLD_LOG2`) which differs from the forward per-tile `tile_row_max`, causing incorrect gradients; fix by persisting the exact forward skip decision (save `can_skip` or `tile_row_max` from the forward path) and use that saved mask in the backward to zero out `p` (instead of recomputing from `lse`), or, until that metadata is stored, disable `APPLY_SKIP_SOFTMAX` for any inputs that require gradients (gate the optimization to inference-only); update the forward to stash the mask (or pre-tile max) and update the backward to read and apply that saved mask when processing `p`.

modelopt/torch/sparsity/attention_sparsity/config.py (1)
99-107: ⚠️ Potential issue | 🟡 Minor — Validate `skip_softmax_threshold` during config parse.

This is a public fraction field, but negatives and values above `1` still validate here and only fail later as confusing kernel behavior. Please reject anything outside `[0.0, 1.0]` at parse time.

🧪 Suggested fix
```diff
 skip_softmax_threshold: float = ModeloptField(
     default=0.1,
     title="Skip-softmax threshold.",
     description=(
         "Tiles contributing less than this fraction are skipped entirely. "
         "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
         "Set to 0 to disable."
     ),
+    ge=0.0,
+    le=1.0,
 )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107, Add validation for skip_softmax_threshold in the config parsing so values outside [0.0, 1.0] are rejected early: in the code that constructs/parses the config (the ModeloptField definition for skip_softmax_threshold or the config class's validation/__post_init__/validate method), check the skip_softmax_threshold value and raise a clear ValueError (or use the config validation mechanism) if skip_softmax_threshold < 0.0 or skip_softmax_threshold > 1.0, ensuring any negative or >1 inputs fail at parse time rather than later in the Triton kernel.

tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py (1)
562-600: ⚠️ Potential issue | 🟡 Minor — This integration test still doesn't prove skip-softmax was enabled.

The case is intentionally short and then asserts `logits_skip ~= logits_dense`, so it still passes if `triton_attention_forward()` never forwards `skip_softmax_threshold` or if the threshold never causes any tile to be skipped on this model. Please force a multi-tile prompt / larger tiny-model fixture, or assert that the kwarg reaches `attention()` directly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines 562 - 600, The test currently never verifies that skip-softmax was actually enabled; update test_skip_softmax_via_sparsify to either (A) force multi-tile behavior by using a longer input (increase ids length beyond a single tile/sequence chunk) or set SKIP_SOFTMAX_TRITON_DEFAULT to a value that will trigger skipping, and then check outputs, or (B) directly assert the kwarg is forwarded by monkeypatching/wrapping the attention implementation (wrap triton_attention_forward or the model's attention() method instances obtained from the loaded AutoModelForCausalLM) to capture its kwargs and assert skip_softmax_threshold (or a boolean like skip_softmax) is present and set; reference functions/classes: test_skip_softmax_via_sparsify, mtsa.sparsify, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT, triton_attention_forward, attention().
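For option (B), a pytest-style sketch of the capture (the patch target and fixtures below are assumptions; point the patch at wherever the sparsified forward actually imports `attention` from):

```python
import modelopt.torch.kernels.hf_triton_attention as hf_attn

def test_threshold_reaches_kernel(monkeypatch, sparsified_model, ids):
    captured = {}
    real_attention = hf_attn.attention

    def spy(*args, **kwargs):
        # Record kwargs on their way into the kernel, then run it unchanged.
        captured.update(kwargs)
        return real_attention(*args, **kwargs)

    monkeypatch.setattr(hf_attn, "attention", spy)
    sparsified_model(ids)
    assert captured.get("skip_softmax_threshold", 0) > 0
```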
📒 Files selected for processing (7)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
```python
def test_monotonic_approximation_error(self):
    """Larger threshold -> larger error vs dense (monotonic degradation)."""
    q, k, v, locs, lens = self._make_inputs(seq_len=512)
    scale = 1.0 / (64**0.5)
    out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale)
    errors = []
    for threshold in [1e-4, 1e-2, 1e-1]:
        out_skip = attention(
            q, k, v, locs, lens, 512, softmax_scale=scale, skip_softmax_threshold=threshold
        )
        errors.append((out_skip - out_dense).abs().mean().item())
    assert errors[0] <= errors[1] <= errors[2], f"Errors not monotonic: {errors}"
```
`errors[0] <= errors[1] <= errors[2]` is not a stable correctness invariant.
Higher thresholds skip more work, but the output error vs. dense can still decrease because the dropped V contributions can cancel. That makes this assertion a potential flake on a correct kernel. I’d keep the existing small/large-threshold sanity checks and drop this total ordering unless you can assert a more direct monotonic signal.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines
494 - 505, Remove the unstable total ordering assertion in
test_monotonic_approximation_error that requires errors[0] <= errors[1] <=
errors[2]; instead keep deterministic sanity checks: ensure each computed error
is finite and non-negative (e.g., errors[i] is not NaN and errors[i] >= 0) and
replace the strict chain with a single small-vs-large check between the smallest
and largest threshold (e.g., errors[0] <= errors[2]) to validate that very small
thresholds produce no larger error than very large thresholds; update references
to attention, skip_softmax_threshold, out_dense, out_skip, and errors
accordingly.
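One concrete shape for that weakening, reusing the test's own names (`tol` is an illustrative value to tune per dtype/GPU):

```python
import math

errors = []
for threshold in [1e-4, 1e-2, 1e-1]:
    out_skip = attention(
        q, k, v, locs, lens, 512, softmax_scale=scale, skip_softmax_threshold=threshold
    )
    err = (out_skip - out_dense).abs().mean().item()
    assert math.isfinite(err) and err >= 0.0  # per-value sanity, never flaky
    errors.append(err)

tol = 1e-3  # illustrative tolerance
# Compare only the extremes instead of asserting a strict monotonic chain.
assert errors[0] <= errors[-1] + tol, f"Unexpected error ordering: {errors}"
```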
Actionable comments posted: 2
♻️ Duplicate comments (1)
modelopt/torch/kernels/triton_fa.py (1)
169-190: ⚠️ Potential issue | 🔴 Critical — Backward is still replaying a different skip rule than forward.

Forward only skips when `skip_tile` is true, using the per-tile running `row_max`. Backward rebuilds `can_skip` from the final `lse` and then zeros rows individually, so mixed tiles and later tiles can drop gradients for work that was actually done in forward. Please save the exact forward tile-skip decision (or the pre-tile `row_max`) and replay that, or keep this mode inference-only.

🛡️ Safe short-term guard
```diff
 apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+    raise NotImplementedError(
+        "skip_softmax_threshold is inference-only until backward can replay "
+        "the exact forward tile-skip decisions."
+    )
 if apply_skip:
     skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
```
Also applies to: 393-397, 523-527
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 169 - 190, Forward computes tile_row_max, can_skip and skip_tile using row_max and SKIP_THRESHOLD_LOG2 but the backward recomputes can_skip from lse causing mismatch; to fix, capture and store the exact forward decision (e.g., a boolean per-tile mask like skip_tile_mask or the pre-tile row_max values computed by tile_row_max) inside the forward path (near tile_row_max / can_skip / skip_tile) and have the backward path replay that saved mask/value instead of recomputing from lse so gradients are only zeroed for tiles actually skipped in forward (apply same change to the other occurrences referenced around lines 393-397 and 523-527).
📒 Files selected for processing (4)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/kernels/hf_triton_attention.py
```python
# Skip-softmax: convert threshold to scaled log2 space for the kernel.
# The BLASST reference (https://arxiv.org/pdf/2512.12087) checks
# ln(lambda) on unscaled scores. Our kernel works in log2-scaled space
# (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we
# pre-scale: threshold_scaled = log2(lambda) * sm_scale.
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
if apply_skip:
    skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
else:
    skip_threshold_log2 = 0.0
```
Don't multiply the skip cutoff by `sm_scale` again.

`scores` already live in `raw_score * sm_scale * log2(e)` space via `qk_scale`, so the contribution cutoff in that same space is just `log2(lambda)`. Multiplying by `sm_scale` here makes the effective threshold head-dim dependent and much looser than the documented fraction.
🐛 Proposed fix
```diff
-    # pre-scale: threshold_scaled = log2(lambda) * sm_scale.
+    # In kernel space, score deltas already include `sm_scale * log2(e)`,
+    # so the contribution cutoff is simply `log2(lambda)`.
     apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
     if apply_skip:
-        skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
+        skip_threshold_log2 = math.log2(skip_softmax_threshold)
     else:
         skip_threshold_log2 = 0.0
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/triton_fa.py` around lines 581 - 590, The skip-softmax
threshold is being over-scaled: when computing skip_threshold_log2 inside the
function handling skip_softmax_threshold, do not multiply
math.log2(skip_softmax_threshold) by sm_scale because scores are already in
qk_scale (raw_score * sm_scale * LOG2E) space; set skip_threshold_log2 =
math.log2(skip_softmax_threshold) when apply_skip is true (leave the else branch
as 0.0), and keep references to skip_softmax_threshold, skip_threshold_log2,
sm_scale, qk_scale, and scores to locate and update the calculation.
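The algebra behind this fix can be spot-checked in isolation; both conditions below reduce to `sm_scale * (tile - run) < ln(lambda)`, so the assert holds for any inputs (the concrete numbers are arbitrary):

```python
# With kernel scores pre-multiplied by qk_scale = sm_scale * log2(e), the
# "tile contributes less than fraction lambda" test needs only a log2(lambda)
# cutoff; scaling it by sm_scale again changes which tiles get skipped.
import math

sm_scale, lam = 1 / math.sqrt(64), 1e-2
raw_tile_max, raw_run_max = 3.0, 11.0

# Documented semantics: contribution exp(sm_scale * (tile - run)) < lambda.
contrib_skip = math.exp(sm_scale * (raw_tile_max - raw_run_max)) < lam

# Same test expressed on qk_scale-multiplied (kernel-space) scores.
qk_scale = sm_scale * math.log2(math.e)
kernel_skip = raw_tile_max * qk_scale < raw_run_max * qk_scale + math.log2(lam)

assert contrib_skip == kernel_skip
```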
```python
@contextmanager
def _skip_softmax_context():
    module._apply_skip_softmax = True
    try:
        yield
    finally:
        module._apply_skip_softmax = False
```
Restore the previous `_apply_skip_softmax` state on exit.

The `finally` block always writes `False`, so a nested/re-entrant context on the same module will clear an outer skip-softmax scope too early. Save the old value and restore it instead.
💡 Proposed fix
```diff
 @contextmanager
 def _skip_softmax_context():
-    module._apply_skip_softmax = True
+    prev_apply_skip_softmax = getattr(module, "_apply_skip_softmax", False)
+    module._apply_skip_softmax = True
     try:
         yield
     finally:
-        module._apply_skip_softmax = False
+        module._apply_skip_softmax = prev_apply_skip_softmax
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 51 - 57, the _skip_softmax_context context manager currently
unconditionally sets module._apply_skip_softmax = False on exit, which breaks
nested contexts; modify _skip_softmax_context to save the previous value (e.g.,
old = getattr(module, "_apply_skip_softmax", False)) before setting it to True,
and in the finally block restore module._apply_skip_softmax = old so
nested/re-entrant uses of _skip_softmax_context correctly preserve outer state.
@coderabbitai full review |
✅ Actions performed: full review triggered.
Signed-off-by: Kai Xu <kaix@nvidia.com>
Force-pushed 7c966b1 to 59849a1
♻️ Duplicate comments (3)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
99-107: ⚠️ Potential issue | 🟡 Minor — Add range validation for `skip_softmax_threshold`.

The field accepts any float, but values outside `[0, 1]` are invalid (negative values or `>1` would cause unexpected kernel behavior). Add Pydantic constraints to reject invalid values at parse time.

🧩 Suggested constraint
```diff
 skip_softmax_threshold: float = ModeloptField(
     default=0.1,
     title="Skip-softmax threshold.",
     description=(
         "Tiles contributing less than this fraction are skipped entirely. "
         "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
         "Set to 0 to disable."
     ),
+    ge=0.0,
+    le=1.0,
 )
```
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107, Add validation to skip_softmax_threshold so values are constrained to [0,1]: update the ModeloptField declaration for skip_softmax_threshold in config.py (the skip_softmax_threshold field) to include Pydantic range constraints (e.g., pass ge=0 and le=1 to ModeloptField or change the type to a pydantic.confloat(ge=0, le=1)) so parsing rejects negative or >1 values.

modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)
48-59: ⚠️ Potential issue | 🟡 Minor — Restore previous `_apply_skip_softmax` state on exit.

The `finally` block unconditionally sets `False`, breaking nested/re-entrant contexts. Save and restore the old value instead.

💡 Proposed fix
```diff
 @contextmanager
 def _skip_softmax_context():
+    prev = getattr(module, "_apply_skip_softmax", False)
     module._apply_skip_softmax = True
     try:
         yield
     finally:
-        module._apply_skip_softmax = False
+        module._apply_skip_softmax = prev
```
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py` around lines 48 - 59, The context manager returned by get_sparse_context sets module._apply_skip_softmax to True but unconditionally resets it to False on exit, breaking nested/re-entrant usage; modify the inner _skip_softmax_context in get_sparse_context to save the previous value (e.g., prev = module._apply_skip_softmax) before setting True, yield, and then restore module._apply_skip_softmax = prev in the finally block so the original state is preserved for nested contexts.

modelopt/torch/kernels/triton_fa.py (1)
586-590: ⚠️ Potential issue | 🟡 Minor — Validate `skip_softmax_threshold` before computing `log2`.

Invalid values (negative, NaN, inf, or >1) will cause incorrect behavior. Reject them early with a clear error.

🛡️ Proposed fix
```diff
+    import math
+
+    if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+        apply_skip = False
+    elif not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+        raise ValueError(
+            f"skip_softmax_threshold must be in (0, 1], or None/0 to disable, got {skip_softmax_threshold}"
+        )
+    else:
+        apply_skip = True
-    apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
     if apply_skip:
         skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
     else:
         skip_threshold_log2 = 0.0
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 586 - 590, Validate skip_softmax_threshold before using math.log2: when skip_softmax_threshold is not None, check with math.isfinite and not math.isnan and ensure 0 < skip_softmax_threshold <= 1 (treat 0 or None as "no skip"); if the value fails these checks raise a ValueError with a clear message referencing skip_softmax_threshold. Then keep the existing logic that sets apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 and compute skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale only when apply_skip; otherwise set skip_threshold_log2 = 0.0. Use the variable names skip_softmax_threshold, apply_skip, skip_threshold_log2, and sm_scale to locate and update the code.
📒 Files selected for processing (8)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
There was a problem hiding this comment.
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)
37-48: Consider adding type hints for consistency.

The codebase uses type hints extensively. Adding annotations would improve static type checking with mypy.

💡 Suggested type hints
```diff
-    def __init__(self, method_config=None):
+    def __init__(self, method_config: dict | None = None):
         """Initialize with skip-softmax threshold from config."""
         super().__init__()
         method_config = method_config or {}
         self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1)

     @property
     def name(self) -> str:
         """Method name identifier."""
         return "triton_skip_softmax"

-    def get_sparse_context(self, module):
+    def get_sparse_context(self, module: "torch.nn.Module"):
         """Return context manager that activates skip-softmax during forward."""
```
Note: You'll need to add `from __future__ import annotations` or `import torch` for the type hint.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py` around lines 37 - 48, Add static type annotations to the constructor and methods in this file: annotate the __init__ parameter method_config (e.g., Optional[Dict[str, Any]]) and the instance attribute skip_softmax_threshold as float, annotate the name property return type as str (already hinted but keep consistent), and annotate get_sparse_context(module) with an appropriate type for module (e.g., torch.nn.Module) and a return type (e.g., Any or a specific SparseContext type). Also add the necessary imports (from __future__ import annotations and typing imports like Optional, Dict, Any, plus import torch) at the top so mypy/static checkers can validate the signatures for __init__, name, and get_sparse_context.
📒 Files selected for processing (8)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (1)
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
🚧 Files skipped from review as they are similar to previous changes (5)
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/kernels/triton_fa.py
Signed-off-by: Kai Xu <kaix@nvidia.com>
Force-pushed 9a03035 to c49bca2
What does this PR do?
Type of change: New feature
Add skip-softmax tile skipping to the Triton flash attention kernel.
Usage
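A hedged sketch of both entry points: the direct kernel call mirrors this PR's tests, while the `sparsify` path is inferred from the config/method names in the diff (exact signatures are assumptions):

```python
# Direct kernel call (as exercised in test_triton_fa.py):
out = attention(
    q, k, v, locs, lens, 512,
    softmax_scale=scale,
    skip_softmax_threshold=0.1,  # fraction in (0, 1]; None/0 disables skipping
)

# Via the attention-sparsity API (names from the tests; entry point assumed):
import modelopt.torch.sparsity.attention_sparsity as mtsa

model = mtsa.sparsify(model, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
```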
Testing
Performance (TFLOPS at seq_len=16384, RTX 6000 Pro):
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
Tests
Documentation