Skip to content

Commit 59849a1

Browse files
committed
Fix the sm_scale issue
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 5ad7cdf commit 59849a1

1 file changed

Lines changed: 39 additions & 22 deletions

File tree

modelopt/torch/kernels/triton_fa.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
metadata (b_start_loc, b_seq_len). Supports causal masking and autograd.
2424
"""
2525

26+
import math
27+
2628
import torch
2729
import triton
2830
import triton.language as tl
@@ -106,7 +108,7 @@ def _attn_fwd(
106108
HEAD_DIM: tl.constexpr, # Actual head dimension (for d_mask)
107109
STORE_LSE: tl.constexpr, # Whether to save LSE for backward pass
108110
APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores
109-
SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(threshold) for skip decision
111+
SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores
110112
):
111113
# --- Grid: (batch, num_q_heads, num_q_tiles) ---
112114
# Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128
@@ -165,25 +167,36 @@ def _attn_fwd(
165167
scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL)
166168

167169
if APPLY_SKIP_SOFTMAX:
168-
# --- Skip-softmax path: check tile, skip V load if all rows negligible ---
169-
# Compute tile row max once — reused for both skip check and softmax update
170-
tile_row_max = tl.max(scores, 1) # [BLOCK_M]
170+
# --- Skip-softmax (BLASST, https://arxiv.org/pdf/2512.12087) ---
171+
#
172+
# Algorithm: During FlashAttention's block-wise computation, we
173+
# maintain a running maximum m_i^(j) across blocks. If a block's
174+
# local maximum ~m_i^(j) is significantly smaller than the running
175+
# maximum m_i^(j):
176+
#
177+
# ~m_i^(j) - m_i^(j) < ln(lambda)
178+
#
179+
# then exp(~m_i^(j) - m_i^(j)) < lambda ≈ 0, meaning the block's
180+
# contribution to the final output is negligible. We skip the
181+
# softmax computation, V load, and BMM2 computation entirely.
182+
#
183+
# The threshold is pre-scaled by qk_scale in the Python wrapper so
184+
# it can be compared directly against scaled scores (matching the
185+
# BLASST reference semantics on unscaled scores).
186+
tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled)
187+
# Per-row: True if row's tile max is negligible vs running max
171188
can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
172-
all_skip = tl.min(can_skip.to(tl.int32)) == 1
189+
# Per-tile: skip entire tile only if ALL rows are negligible
190+
skip_tile = tl.min(can_skip.to(tl.int32)) == 1
173191

174-
if not all_skip:
175-
# Online softmax update (reuses tile_row_max — no second tl.max)
176-
# For skipped rows: tile_row_max < row_max, so m_new = row_max (no change)
192+
if not skip_tile:
177193
m_new = tl.maximum(row_max, tile_row_max)
178194
p = tl.math.exp2(scores - m_new[:, None])
179-
# Zero out skipped rows (instead of masking scores and recomputing max)
180-
p = tl.where(can_skip[:, None], 0.0, p)
181195
l_new = tl.sum(p, 1)
182196
correction = tl.math.exp2(row_max - m_new)
183197
row_sum = row_sum * correction + l_new
184198
acc = acc * correction[:, None]
185199

186-
# Load V and accumulate
187200
v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
188201
v = tl.load(
189202
v_base + v_offs,
@@ -192,7 +205,7 @@ def _attn_fwd(
192205
)
193206
acc = tl.dot(p.to(v.dtype), v, acc)
194207
row_max = m_new
195-
# else: all rows negligible — skip V load, softmax update, accumulation
208+
# else: tile skipped — no softmax computation, V load, or BMM2 computation
196209
else:
197210
# --- Standard path: no skip check ---
198211
# Online softmax update
@@ -565,12 +578,14 @@ def forward(
565578
# Triton tiles must be powers of 2; pad head dim
566579
BLOCK_D = triton.next_power_of_2(HEAD_DIM)
567580

568-
# Skip-softmax: convert threshold to log2 space for the kernel
581+
# Skip-softmax: convert threshold to scaled log2 space for the kernel.
582+
# The BLASST reference (https://arxiv.org/pdf/2512.12087) checks
583+
# ln(lambda) on unscaled scores. Our kernel works in log2-scaled space
584+
# (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we
585+
# pre-scale: threshold_scaled = log2(lambda) * sm_scale.
569586
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
570587
if apply_skip:
571-
import math
572-
573-
skip_threshold_log2 = math.log2(skip_softmax_threshold)
588+
skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
574589
else:
575590
skip_threshold_log2 = 0.0
576591

@@ -759,12 +774,14 @@ def attention(
759774
b_start_loc_k: [batch] start offset for K/V (None = same as Q).
760775
b_seq_len_k: [batch] length for K/V (None = same as Q).
761776
max_input_len_k: Maximum K/V sequence length (None = same as Q).
762-
skip_softmax_threshold: Skip KV tiles whose max attention score is
763-
below ``running_max * threshold`` for all Q rows. This is an
764-
approximation that trades accuracy for speed — tiles with
765-
negligible softmax contributions are skipped entirely (no V
766-
load or accumulation). Set to ``None`` or ``0`` to disable.
767-
Typical values: 1e-3 to 1e-1. The backward pass re-applies
777+
skip_softmax_threshold: BLASST threshold lambda
778+
(https://arxiv.org/pdf/2512.12087). Skip KV tiles where
779+
``exp(tile_max - running_max) < lambda``, meaning the tile's
780+
softmax contribution is negligible. Tiles are skipped entirely
781+
(no softmax, V load, or BMM2). The threshold is applied on
782+
unscaled scores (sm_scale-independent). Set to ``None`` or
783+
``0`` to disable. Typical values: 1e-3 to 1e-1. The backward
784+
pass re-applies
768785
the same skip decision using the saved LSE for consistency.
769786
770787
Returns:

0 commit comments

Comments
 (0)