Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ NVIDIA Model Optimizer Changelog
- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics.
- Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Add skip-softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.

**Bug Fixes**
Expand Down
24 changes: 14 additions & 10 deletions modelopt/torch/kernels/hf_triton_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,20 @@ def triton_attention_forward(
kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
kw["max_input_len_k"] = seq_k

# N:M sparse softmax — prefill only (decode should not sparsify KV)
if not is_decode and getattr(module, "_apply_sparse_nm", False):
# _sparse_method_instance is set by SparseAttentionModule._init_sparse_method()
# in modelopt/torch/sparsity/attention_sparsity/sparse_attention.py
method = getattr(module, "_sparse_method_instance", None)
if method is not None:
kw["sparsity_n"] = getattr(method, "sparsity_n", 2)
kw["sparsity_m"] = getattr(method, "sparsity_m", 4)
kw["num_sink_tokens"] = getattr(method, "num_sink_tokens", 0)
kw["dense_window_size"] = getattr(method, "dense_window_size", 64)
# Sparse attention params
method = getattr(module, "_sparse_method_instance", None)

# N:M sparse softmax: prefill only (no perf benefit for decode)
if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False):
kw["sparsity_n"] = method.sparsity_n
kw["sparsity_m"] = method.sparsity_m
kw["num_sink_tokens"] = method.num_sink_tokens
kw["dense_window_size"] = method.dense_window_size

# Skip-softmax: applies to both prefill and decode
if method is not None and getattr(module, "_apply_skip_softmax", False):
if method.skip_softmax_threshold:
kw["skip_softmax_threshold"] = method.skip_softmax_threshold

o = attention(q, k, v, **kw)

Expand Down
156 changes: 134 additions & 22 deletions modelopt/torch/kernels/triton_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
metadata (b_start_loc, b_seq_len). Supports causal masking and autograd.
"""

import math

import torch
import triton
import triton.language as tl
Expand Down Expand Up @@ -248,6 +250,8 @@ def _attn_fwd(
SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8)
NUM_SINK_TOKENS: tl.constexpr = 0, # KV positions before this are kept dense (attention sinks)
DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent)
APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores
SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores
):
# --- Grid: (batch, num_q_heads, num_q_tiles) ---
# Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128
Expand Down Expand Up @@ -320,26 +324,65 @@ def _attn_fwd(
scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
)

# --- Online softmax update ---
# 1. Update running max
m_new = tl.maximum(row_max, tl.max(scores, 1))
# 2. Compute unnormalized attention weights
p = tl.math.exp2(scores - m_new[:, None])
l_new = tl.sum(p, 1)
# 3. Correction factor: rescale previous tiles when max changes
correction = tl.math.exp2(row_max - m_new)
row_sum = row_sum * correction + l_new
acc = acc * correction[:, None]

# Load V [BLOCK_N, BLOCK_D] and accumulate: acc += attn_weights @ V
v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
v = tl.load(
v_base + v_offs,
mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
other=0.0,
)
acc = tl.dot(p.to(v.dtype), v, acc)
row_max = m_new
if APPLY_SKIP_SOFTMAX:
# --- Skip-softmax (BLASST, https://arxiv.org/pdf/2512.12087) ---
#
# Algorithm: During FlashAttention's block-wise computation, we
# maintain a running maximum m_i^(j) across blocks. If a block's
# local maximum ~m_i^(j) is significantly smaller than the running
# maximum m_i^(j):
#
# ~m_i^(j) - m_i^(j) < ln(lambda)
#
# then exp(~m_i^(j) - m_i^(j)) < lambda ≈ 0, meaning the block's
# contribution to the final output is negligible. We skip the
# softmax computation, V load, and BMM2 computation entirely.
#
# The threshold is pre-scaled by qk_scale in the Python wrapper so
# it can be compared directly against scaled scores (matching the
# BLASST reference semantics on unscaled scores).
tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled)
# Per-row: True if row's tile max is negligible vs running max
can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
# Per-tile: skip entire tile only if ALL rows are negligible
skip_tile = tl.min(can_skip.to(tl.int32)) == 1

if not skip_tile:
m_new = tl.maximum(row_max, tile_row_max)
p = tl.math.exp2(scores - m_new[:, None])
l_new = tl.sum(p, 1)
correction = tl.math.exp2(row_max - m_new)
row_sum = row_sum * correction + l_new
acc = acc * correction[:, None]

v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
v = tl.load(
v_base + v_offs,
mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
other=0.0,
)
acc = tl.dot(p.to(v.dtype), v, acc)
row_max = m_new
# else: tile skipped: no softmax computation, V load, and BMM2 computation
else:
# --- Standard path: no skip check ---
# Online softmax update
m_new = tl.maximum(row_max, tl.max(scores, 1))
p = tl.math.exp2(scores - m_new[:, None])
l_new = tl.sum(p, 1)
correction = tl.math.exp2(row_max - m_new)
row_sum = row_sum * correction + l_new
acc = acc * correction[:, None]

# Load V and accumulate
v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
v = tl.load(
v_base + v_offs,
mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
other=0.0,
)
acc = tl.dot(p.to(v.dtype), v, acc)
row_max = m_new
Comment on lines +327 to +385
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems there is one path for APPLY_SKIP_SOFTMAX=True and one for False. The standard path and the "not all_skip" branch inside the skip path are nearly identical. Since APPLY_SKIP_SOFTMAX is a tl.constexpr, Triton will compile-time eliminate the dead branch, but the source code duplication makes maintenance harder.

Consider structuring as:

# Skip check (compiled out when APPLY_SKIP_SOFTMAX=False)                                                                                                                                                                                                               
  do_process = True                                                                                                                                                                                                                                                       
  if APPLY_SKIP_SOFTMAX:                                                                                                                                                                                                                                                  
      tile_row_max = tl.max(scores, 1)                                                                                                                                                                                                                                    
      can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)                                                                                                                                                                                                           
      all_skip = tl.min(can_skip.to(tl.int32)) == 1                                                                                                                                                                                                                       
      if all_skip:                                                                                                                                                                                                                                                        
          do_process = False                                                                                                                                                                                                                                              
                                                                                                                                                                                                                                                                          
  if do_process:                                                                                                                                                                                                                                                          
      # Single copy of softmax update + V accumulation                                                                                                                                                                                                                    
      ...                                                                                                                                                                                                                                                                 
      if APPLY_SKIP_SOFTMAX:                                                                                                                                                                                                                                              
          p = tl.where(can_skip[:, None], 0.0, p)                                                                                                                                                                                                                         
      ...   

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good observation. The duplication is intentional. I considered this refactoring but Triton doesn't support reassigning a variable from a tl.constexpr branch and using it as a runtime condition. The reassignment is silently ignored. If we remove tl.constexpr to work around this, the runtime if/else adds branch divergence overhead on every attention tile iteration.


# --- Final normalization: output = acc / row_sum ---
acc = acc / row_sum[:, None]
Expand Down Expand Up @@ -440,6 +483,8 @@ def _attn_bwd_dq(
SPARSITY_M: tl.constexpr = 4,
NUM_SINK_TOKENS: tl.constexpr = 0,
DENSE_WINDOW_SIZE: tl.constexpr = 64,
APPLY_SKIP_SOFTMAX: tl.constexpr = False,
SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0,
):
"""Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles.

Expand Down Expand Up @@ -523,6 +568,16 @@ def _attn_bwd_dq(

p = tl.math.exp2(scores - lse[:, None])

# Skip-softmax backward: zero out P for rows with negligible contribution.
# Per-row using final LSE because forward/backward tile sizes may differ
# (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile
# skip masks from forward wouldn't align. LSE >= any intermediate running
# max, so this conservatively zeros out at least what forward skipped.
if APPLY_SKIP_SOFTMAX:
tile_row_max = tl.max(scores, 1)
can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
p = tl.where(can_skip[:, None], 0.0, p)

# dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K
dp = tl.dot(do, tl.trans(v))
ds = p * (dp - row_delta[:, None])
Expand Down Expand Up @@ -574,6 +629,8 @@ def _attn_bwd_dkdv(
SPARSITY_M: tl.constexpr = 4,
NUM_SINK_TOKENS: tl.constexpr = 0,
DENSE_WINDOW_SIZE: tl.constexpr = 64,
APPLY_SKIP_SOFTMAX: tl.constexpr = False,
SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0,
):
"""Phase 2 of backward: compute dK, dV for one KV tile.

Expand Down Expand Up @@ -665,6 +722,16 @@ def _attn_bwd_dkdv(

p = tl.math.exp2(scores - lse[:, None])

# Skip-softmax backward: zero out P for rows with negligible contribution.
# Per-row using final LSE because forward/backward tile sizes may differ
# (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile
# skip masks from forward wouldn't align. LSE >= any intermediate running
# max, so this conservatively zeros out at least what forward skipped.
if APPLY_SKIP_SOFTMAX:
tile_row_max = tl.max(scores, 1)
can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So forward and backward are different — forward skips based on the running row max, backward on lse.

Since lse >= row_max always holds, the backward condition is satisfied more often — it will zero out at least the rows the forward skipped, and possibly more. This means the backward may drop gradient contributions for tiles that were actually computed in the forward.

Is this intentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. This is intentional. It's not trivial to exactly reproduce the forward skip decisions because:

  • Forward uses the running max at each tile step (not saved).
  • Forward and backward may use different BLOCK_N (forward autotunes, backward uses fixed 64).

p = tl.where(can_skip[:, None], 0.0, p)

# dV += P^T @ dO
dv += tl.dot(tl.trans(p.to(do_tile.dtype)), do_tile)
# dS = P * (dO @ V^T - delta), dK += dS^T @ Q
Expand Down Expand Up @@ -700,6 +767,7 @@ def forward(
sparsity_m,
num_sink_tokens,
dense_window_size,
skip_softmax_threshold,
):
HEAD_DIM = q.shape[2]
num_q_heads = q.shape[1]
Expand All @@ -720,6 +788,17 @@ def forward(
# Triton tiles must be powers of 2; pad head dim
BLOCK_D = triton.next_power_of_2(HEAD_DIM)

# Skip-softmax: convert threshold to scaled log2 space for the kernel.
# The BLASST reference (https://arxiv.org/pdf/2512.12087) checks
# ln(lambda) on unscaled scores. Our kernel works in log2-scaled space
# (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we
# pre-scale: threshold_scaled = log2(lambda) * sm_scale.
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
if apply_skip:
skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
else:
skip_threshold_log2 = 0.0
Comment on lines +791 to +800
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Don't multiply the skip cutoff by sm_scale again.

scores already live in raw_score * sm_scale * log2(e) space via qk_scale, so the contribution cutoff in that same space is just log2(lambda). Multiplying by sm_scale here makes the effective threshold head-dim dependent and much looser than the documented fraction.

🐛 Proposed fix
-        # pre-scale: threshold_scaled = log2(lambda) * sm_scale.
+        # In kernel space, score deltas already include `sm_scale * log2(e)`,
+        # so the contribution cutoff is simply `log2(lambda)`.
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
         if apply_skip:
-            skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
+            skip_threshold_log2 = math.log2(skip_softmax_threshold)
         else:
             skip_threshold_log2 = 0.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 581 - 590, The skip-softmax
threshold is being over-scaled: when computing skip_threshold_log2 inside the
function handling skip_softmax_threshold, do not multiply
math.log2(skip_softmax_threshold) by sm_scale because scores are already in
qk_scale (raw_score * sm_scale * LOG2E) space; set skip_threshold_log2 =
math.log2(skip_softmax_threshold) when apply_skip is true (leave the else branch
as 0.0), and keep references to skip_softmax_threshold, skip_threshold_log2,
sm_scale, qk_scale, and scores to locate and update the calculation.


o = torch.empty_like(q)
lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32)

Expand Down Expand Up @@ -758,6 +837,8 @@ def grid(META):
SPARSITY_M=sparsity_m,
NUM_SINK_TOKENS=num_sink_tokens,
DENSE_WINDOW_SIZE=dense_window_size,
APPLY_SKIP_SOFTMAX=apply_skip,
SKIP_THRESHOLD_LOG2=skip_threshold_log2,
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
)

Expand All @@ -776,6 +857,8 @@ def grid(META):
ctx.sparsity_m = sparsity_m
ctx.num_sink_tokens = num_sink_tokens
ctx.dense_window_size = dense_window_size
ctx.apply_skip = apply_skip
ctx.skip_threshold_log2 = skip_threshold_log2
return o

@staticmethod
Expand Down Expand Up @@ -854,6 +937,8 @@ def backward(ctx, grad_output):
SPARSITY_M=ctx.sparsity_m,
NUM_SINK_TOKENS=ctx.num_sink_tokens,
DENSE_WINDOW_SIZE=ctx.dense_window_size,
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
num_warps=num_warps,
num_stages=1,
)
Expand All @@ -877,11 +962,30 @@ def backward(ctx, grad_output):
SPARSITY_M=ctx.sparsity_m,
NUM_SINK_TOKENS=ctx.num_sink_tokens,
DENSE_WINDOW_SIZE=ctx.dense_window_size,
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
num_warps=num_warps,
num_stages=1,
)

return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
return (
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)


def attention(
Expand All @@ -901,8 +1005,9 @@ def attention(
sparsity_m: int = 4,
num_sink_tokens: int = 0,
dense_window_size: int = 64,
skip_softmax_threshold: float | None = None,
) -> torch.Tensor:
"""Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax.
"""Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax.

Args:
q: [total_q_tokens, num_q_heads, head_dim]
Expand All @@ -926,6 +1031,12 @@ def attention(
dense_window_size: Tokens near the query diagonal kept dense (local
attention window). Absolute token count, BLOCK_N-independent.
Default 64 (one reference block).
skip_softmax_threshold: BLASST threshold lambda
(https://arxiv.org/pdf/2512.12087). Skip KV tiles where
``exp(tile_max - running_max) < lambda``, meaning the tile's
softmax contribution is negligible. Tiles are skipped entirely
(no softmax, V load, or BMM2). The threshold is applied on
unscaled scores. Set to ``None`` or ``0`` to disable.

Returns:
Output tensor [total_q_tokens, num_q_heads, head_dim].
Expand All @@ -947,6 +1058,7 @@ def attention(
sparsity_m,
num_sink_tokens,
dense_window_size,
skip_softmax_threshold,
)


Expand Down
25 changes: 25 additions & 0 deletions modelopt/torch/sparsity/attention_sparsity/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
),
)

skip_softmax_threshold: float = ModeloptField(
default=0.1,
title="Skip-softmax threshold.",
description=(
"Tiles contributing less than this fraction are skipped entirely. "
"Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
"Set to 0 to disable."
),
)

@field_validator("method")
@classmethod
def validate_method(cls, v):
Expand Down Expand Up @@ -528,9 +538,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
}


# Default skip-softmax configuration for the Triton kernel backend.
# Follows the same shape as the other *_DEFAULT config dicts in this module:
# a "sparse_cfg" mapping from module-name patterns to per-module settings.
SKIP_SOFTMAX_TRITON_DEFAULT = {
    "sparse_cfg": {
        # Pattern intended to match attention submodules by name
        # (matching semantics are defined by the sparse_cfg consumer —
        # presumably wildcard/glob matching; confirm against the applier).
        "*attn*": {
            "method": "triton_skip_softmax",
            # BLASST threshold lambda: tiles contributing less than this
            # fraction are skipped entirely; 0 disables (see the
            # skip_softmax_threshold field description in this module).
            "skip_softmax_threshold": 0.1,
            "backend": "triton",
            "enable": True,
        },
        # All other modules: sparse attention disabled.
        "default": {"enable": False},
    },
}


__all__ = [
"SKIP_SOFTMAX_CALIB",
"SKIP_SOFTMAX_DEFAULT",
"SKIP_SOFTMAX_TRITON_DEFAULT",
"SPARSE_SOFTMAX_DEFAULT",
"CalibrationConfig",
"FlashSkipSoftmaxConfig",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
]

# Import method implementations to trigger registration
from . import flash_skip_softmax, triton_sparse_softmax
from . import flash_skip_softmax, triton_skip_softmax, triton_sparse_softmax
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ def apply_sparsity(
Returns:
Masked attention scores with sparse elements set to -inf
"""
raise NotImplementedError(
f"{type(self).__name__} does not implement apply_sparsity. "
"Sparsity may be fused into the kernel (Triton backend)."
)
raise NotImplementedError(f"{type(self).__name__} does not implement apply_sparsity.")

def get_sparse_context(self, module: torch.nn.Module):
"""Return a context manager that activates this method's sparsity during forward.
Expand Down
Loading
Loading