2323metadata (b_start_loc, b_seq_len). Supports causal masking and autograd.
2424"""
2525
26+ import math
27+
2628import torch
2729import triton
2830import triton .language as tl
@@ -106,7 +108,7 @@ def _attn_fwd(
106108 HEAD_DIM : tl .constexpr , # Actual head dimension (for d_mask)
107109 STORE_LSE : tl .constexpr , # Whether to save LSE for backward pass
108110 APPLY_SKIP_SOFTMAX : tl .constexpr = False , # Skip KV tiles with negligible scores
109- SKIP_THRESHOLD_LOG2 : tl .constexpr = 0.0 , # log2(threshold) for skip decision
111+ SKIP_THRESHOLD_LOG2 : tl .constexpr = 0.0 , # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores
110112):
111113 # --- Grid: (batch, num_q_heads, num_q_tiles) ---
112114 # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128
@@ -165,25 +167,36 @@ def _attn_fwd(
165167 scores = _apply_mask (scores , q_pos , kv_pos , seq_len_q , seq_len_kv , kv_start , IS_CAUSAL )
166168
167169 if APPLY_SKIP_SOFTMAX :
168- # --- Skip-softmax path: check tile, skip V load if all rows negligible ---
169- # Compute tile row max once — reused for both skip check and softmax update
170- tile_row_max = tl .max (scores , 1 ) # [BLOCK_M]
170+ # --- Skip-softmax (BLASST, https://arxiv.org/pdf/2512.12087) ---
171+ #
172+ # Algorithm: During FlashAttention's block-wise computation, we
173+ # maintain a running maximum m_i^(j) across blocks. If a block's
174+ # local maximum ~m_i^(j) is significantly smaller than the running
175+ # maximum m_i^(j):
176+ #
177+ # ~m_i^(j) - m_i^(j) < ln(lambda)
178+ #
179+ # then exp(~m_i^(j) - m_i^(j)) < lambda ≈ 0, meaning the block's
180+ # contribution to the final output is negligible. We skip the
181+ # softmax computation, V load, and BMM2 computation entirely.
182+ #
183+ # The threshold is pre-scaled by qk_scale in the Python wrapper so
184+ # it can be compared directly against scaled scores (matching the
185+ # BLASST reference semantics on unscaled scores).
186+ tile_row_max = tl .max (scores , 1 ) # [BLOCK_M] — ~m_i^(j) (scaled)
187+ # Per-row: True if row's tile max is negligible vs running max
171188 can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2 )
172- all_skip = tl .min (can_skip .to (tl .int32 )) == 1
189+ # Per-tile: skip entire tile only if ALL rows are negligible
190+ skip_tile = tl .min (can_skip .to (tl .int32 )) == 1
173191
174- if not all_skip :
175- # Online softmax update (reuses tile_row_max — no second tl.max)
176- # For skipped rows: tile_row_max < row_max, so m_new = row_max (no change)
192+ if not skip_tile :
177193 m_new = tl .maximum (row_max , tile_row_max )
178194 p = tl .math .exp2 (scores - m_new [:, None ])
179- # Zero out skipped rows (instead of masking scores and recomputing max)
180- p = tl .where (can_skip [:, None ], 0.0 , p )
181195 l_new = tl .sum (p , 1 )
182196 correction = tl .math .exp2 (row_max - m_new )
183197 row_sum = row_sum * correction + l_new
184198 acc = acc * correction [:, None ]
185199
186- # Load V and accumulate
187200 v_offs = (kv_offset + kv_start + kv_pos [:, None ]) * stride_vbs + dim_pos [None , :]
188201 v = tl .load (
189202 v_base + v_offs ,
@@ -192,7 +205,7 @@ def _attn_fwd(
192205 )
193206 acc = tl .dot (p .to (v .dtype ), v , acc )
194207 row_max = m_new
195- # else: all rows negligible — skip V load, softmax update, accumulation
195- 208+ # else: tile skipped — no softmax update, V load, or BMM2 accumulation
196209 else :
197210 # --- Standard path: no skip check ---
198211 # Online softmax update
@@ -565,12 +578,14 @@ def forward(
565578 # Triton tiles must be powers of 2; pad head dim
566579 BLOCK_D = triton .next_power_of_2 (HEAD_DIM )
567580
568- # Skip-softmax: convert threshold to log2 space for the kernel
581+ # Skip-softmax: convert threshold to scaled log2 space for the kernel.
582+ # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks
583+ # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space
584+ # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we
585+ # pre-scale: threshold_scaled = log2(lambda) * sm_scale.
569586 apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
570587 if apply_skip :
571- import math
572-
573- skip_threshold_log2 = math .log2 (skip_softmax_threshold )
588+ skip_threshold_log2 = math .log2 (skip_softmax_threshold ) * sm_scale
574589 else :
575590 skip_threshold_log2 = 0.0
576591
@@ -759,12 +774,14 @@ def attention(
759774 b_start_loc_k: [batch] start offset for K/V (None = same as Q).
760775 b_seq_len_k: [batch] length for K/V (None = same as Q).
761776 max_input_len_k: Maximum K/V sequence length (None = same as Q).
762- skip_softmax_threshold: Skip KV tiles whose max attention score is
763- below ``running_max * threshold`` for all Q rows. This is an
764- approximation that trades accuracy for speed — tiles with
765- negligible softmax contributions are skipped entirely (no V
766- load or accumulation). Set to ``None`` or ``0`` to disable.
767- Typical values: 1e-3 to 1e-1. The backward pass re-applies
777+ skip_softmax_threshold: BLASST threshold lambda
778+ (https://arxiv.org/pdf/2512.12087). Skip KV tiles where
779+ ``exp(tile_max - running_max) < lambda``, meaning the tile's
780+ softmax contribution is negligible. Tiles are skipped entirely
781+ (no softmax, V load, or BMM2). The threshold is applied on
782+ unscaled scores (sm_scale-independent). Set to ``None`` or
783+ ``0`` to disable. Typical values: 1e-3 to 1e-1. The backward
784+ pass re-applies
768785 the same skip decision using the saved LSE for consistency.
769786
770787 Returns: