-
Notifications
You must be signed in to change notification settings - Fork 318
[3/n] Add skip-softmax to Triton flash attention kernel #1081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
60be93f
420251e
0fbbf97
724be3d
188f4d9
7d76af8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,8 @@ | |
| metadata (b_start_loc, b_seq_len). Supports causal masking and autograd. | ||
| """ | ||
|
|
||
| import math | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
@@ -248,6 +250,8 @@ def _attn_fwd( | |
| SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8) | ||
| NUM_SINK_TOKENS: tl.constexpr = 0, # KV positions before this are kept dense (attention sinks) | ||
| DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) | ||
| APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores | ||
| SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores | ||
| ): | ||
| # --- Grid: (batch, num_q_heads, num_q_tiles) --- | ||
| # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 | ||
|
|
@@ -320,26 +324,65 @@ def _attn_fwd( | |
| scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M | ||
| ) | ||
|
|
||
| # --- Online softmax update --- | ||
| # 1. Update running max | ||
| m_new = tl.maximum(row_max, tl.max(scores, 1)) | ||
| # 2. Compute unnormalized attention weights | ||
| p = tl.math.exp2(scores - m_new[:, None]) | ||
| l_new = tl.sum(p, 1) | ||
| # 3. Correction factor: rescale previous tiles when max changes | ||
| correction = tl.math.exp2(row_max - m_new) | ||
| row_sum = row_sum * correction + l_new | ||
| acc = acc * correction[:, None] | ||
|
|
||
| # Load V [BLOCK_N, BLOCK_D] and accumulate: acc += attn_weights @ V | ||
| v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] | ||
| v = tl.load( | ||
| v_base + v_offs, | ||
| mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], | ||
| other=0.0, | ||
| ) | ||
| acc = tl.dot(p.to(v.dtype), v, acc) | ||
| row_max = m_new | ||
| if APPLY_SKIP_SOFTMAX: | ||
| # --- Skip-softmax (BLASST, https://arxiv.org/pdf/2512.12087) --- | ||
| # | ||
| # Algorithm: During FlashAttention's block-wise computation, we | ||
| # maintain a running maximum m_i^(j) across blocks. If a block's | ||
| # local maximum ~m_i^(j) is significantly smaller than the running | ||
| # maximum m_i^(j): | ||
| # | ||
| # ~m_i^(j) - m_i^(j) < ln(lambda) | ||
| # | ||
| # then exp(~m_i^(j) - m_i^(j)) < lambda ≈ 0, meaning the block's | ||
| # contribution to the final output is negligible. We skip the | ||
| # softmax computation, V load, and BMM2 computation entirely. | ||
| # | ||
| # The threshold is pre-scaled by qk_scale in the Python wrapper so | ||
| # it can be compared directly against scaled scores (matching the | ||
| # BLASST reference semantics on unscaled scores). | ||
| tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) | ||
| # Per-row: True if row's tile max is negligible vs running max | ||
| can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) | ||
| # Per-tile: skip entire tile only if ALL rows are negligible | ||
| skip_tile = tl.min(can_skip.to(tl.int32)) == 1 | ||
|
|
||
| if not skip_tile: | ||
| m_new = tl.maximum(row_max, tile_row_max) | ||
| p = tl.math.exp2(scores - m_new[:, None]) | ||
| l_new = tl.sum(p, 1) | ||
| correction = tl.math.exp2(row_max - m_new) | ||
| row_sum = row_sum * correction + l_new | ||
| acc = acc * correction[:, None] | ||
|
|
||
| v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] | ||
| v = tl.load( | ||
| v_base + v_offs, | ||
| mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], | ||
| other=0.0, | ||
| ) | ||
| acc = tl.dot(p.to(v.dtype), v, acc) | ||
| row_max = m_new | ||
| # else: tile skipped: no softmax computation, V load, and BMM2 computation | ||
| else: | ||
| # --- Standard path: no skip check --- | ||
| # Online softmax update | ||
| m_new = tl.maximum(row_max, tl.max(scores, 1)) | ||
| p = tl.math.exp2(scores - m_new[:, None]) | ||
| l_new = tl.sum(p, 1) | ||
| correction = tl.math.exp2(row_max - m_new) | ||
| row_sum = row_sum * correction + l_new | ||
| acc = acc * correction[:, None] | ||
|
|
||
| # Load V and accumulate | ||
| v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] | ||
| v = tl.load( | ||
| v_base + v_offs, | ||
| mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], | ||
| other=0.0, | ||
| ) | ||
| acc = tl.dot(p.to(v.dtype), v, acc) | ||
| row_max = m_new | ||
|
|
||
| # --- Final normalization: output = acc / row_sum --- | ||
| acc = acc / row_sum[:, None] | ||
|
|
@@ -440,6 +483,8 @@ def _attn_bwd_dq( | |
| SPARSITY_M: tl.constexpr = 4, | ||
| NUM_SINK_TOKENS: tl.constexpr = 0, | ||
| DENSE_WINDOW_SIZE: tl.constexpr = 64, | ||
| APPLY_SKIP_SOFTMAX: tl.constexpr = False, | ||
| SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, | ||
| ): | ||
| """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles. | ||
|
|
||
|
|
@@ -523,6 +568,16 @@ def _attn_bwd_dq( | |
|
|
||
| p = tl.math.exp2(scores - lse[:, None]) | ||
|
|
||
| # Skip-softmax backward: zero out P for rows with negligible contribution. | ||
| # Per-row using final LSE because forward/backward tile sizes may differ | ||
| # (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile | ||
| # skip masks from forward wouldn't align. LSE >= any intermediate running | ||
| # max, so this conservatively zeros out at least what forward skipped. | ||
| if APPLY_SKIP_SOFTMAX: | ||
| tile_row_max = tl.max(scores, 1) | ||
| can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2) | ||
| p = tl.where(can_skip[:, None], 0.0, p) | ||
|
|
||
| # dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K | ||
| dp = tl.dot(do, tl.trans(v)) | ||
| ds = p * (dp - row_delta[:, None]) | ||
|
|
@@ -574,6 +629,8 @@ def _attn_bwd_dkdv( | |
| SPARSITY_M: tl.constexpr = 4, | ||
| NUM_SINK_TOKENS: tl.constexpr = 0, | ||
| DENSE_WINDOW_SIZE: tl.constexpr = 64, | ||
| APPLY_SKIP_SOFTMAX: tl.constexpr = False, | ||
| SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, | ||
| ): | ||
| """Phase 2 of backward: compute dK, dV for one KV tile. | ||
|
|
||
|
|
@@ -665,6 +722,16 @@ def _attn_bwd_dkdv( | |
|
|
||
| p = tl.math.exp2(scores - lse[:, None]) | ||
|
|
||
| # Skip-softmax backward: zero out P for rows with negligible contribution. | ||
| # Per-row using final LSE because forward/backward tile sizes may differ | ||
| # (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile | ||
| # skip masks from forward wouldn't align. LSE >= any intermediate running | ||
| # max, so this conservatively zeros out at least what forward skipped. | ||
| if APPLY_SKIP_SOFTMAX: | ||
| tile_row_max = tl.max(scores, 1) | ||
| can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So the forward and backward passes are different — the forward skips based on row_max, the backward on lse. Since lse >= row_max always holds, the backward threshold is strictly looser. This means the backward may handle tiles differently from how the forward did. Is this intentional?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. This is intentional. It's not trivial to exactly reproduce the forward skip decisions because:
|
||
| p = tl.where(can_skip[:, None], 0.0, p) | ||
|
|
||
| # dV += P^T @ dO | ||
| dv += tl.dot(tl.trans(p.to(do_tile.dtype)), do_tile) | ||
| # dS = P * (dO @ V^T - delta), dK += dS^T @ Q | ||
|
|
@@ -700,6 +767,7 @@ def forward( | |
| sparsity_m, | ||
| num_sink_tokens, | ||
| dense_window_size, | ||
| skip_softmax_threshold, | ||
| ): | ||
| HEAD_DIM = q.shape[2] | ||
| num_q_heads = q.shape[1] | ||
|
|
@@ -720,6 +788,17 @@ def forward( | |
| # Triton tiles must be powers of 2; pad head dim | ||
| BLOCK_D = triton.next_power_of_2(HEAD_DIM) | ||
|
|
||
| # Skip-softmax: convert threshold to scaled log2 space for the kernel. | ||
| # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks | ||
| # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space | ||
| # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we | ||
| # pre-scale: threshold_scaled = log2(lambda) * sm_scale. | ||
| apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 | ||
| if apply_skip: | ||
| skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale | ||
| else: | ||
| skip_threshold_log2 = 0.0 | ||
|
Comment on lines
+791
to
+800
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't multiply the skip cutoff by
🐛 Proposed fix- # pre-scale: threshold_scaled = log2(lambda) * sm_scale.
+ # In kernel space, score deltas already include `sm_scale * log2(e)`,
+ # so the contribution cutoff is simply `log2(lambda)`.
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
if apply_skip:
- skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
+ skip_threshold_log2 = math.log2(skip_softmax_threshold)
else:
skip_threshold_log2 = 0.0🤖 Prompt for AI Agents |
||
|
|
||
| o = torch.empty_like(q) | ||
| lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32) | ||
|
|
||
|
|
@@ -758,6 +837,8 @@ def grid(META): | |
| SPARSITY_M=sparsity_m, | ||
| NUM_SINK_TOKENS=num_sink_tokens, | ||
| DENSE_WINDOW_SIZE=dense_window_size, | ||
| APPLY_SKIP_SOFTMAX=apply_skip, | ||
| SKIP_THRESHOLD_LOG2=skip_threshold_log2, | ||
| # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune | ||
| ) | ||
|
|
||
|
|
@@ -776,6 +857,8 @@ def grid(META): | |
| ctx.sparsity_m = sparsity_m | ||
| ctx.num_sink_tokens = num_sink_tokens | ||
| ctx.dense_window_size = dense_window_size | ||
| ctx.apply_skip = apply_skip | ||
| ctx.skip_threshold_log2 = skip_threshold_log2 | ||
| return o | ||
|
|
||
| @staticmethod | ||
|
|
@@ -854,6 +937,8 @@ def backward(ctx, grad_output): | |
| SPARSITY_M=ctx.sparsity_m, | ||
| NUM_SINK_TOKENS=ctx.num_sink_tokens, | ||
| DENSE_WINDOW_SIZE=ctx.dense_window_size, | ||
| APPLY_SKIP_SOFTMAX=ctx.apply_skip, | ||
| SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, | ||
| num_warps=num_warps, | ||
| num_stages=1, | ||
| ) | ||
|
|
@@ -877,11 +962,30 @@ def backward(ctx, grad_output): | |
| SPARSITY_M=ctx.sparsity_m, | ||
| NUM_SINK_TOKENS=ctx.num_sink_tokens, | ||
| DENSE_WINDOW_SIZE=ctx.dense_window_size, | ||
| APPLY_SKIP_SOFTMAX=ctx.apply_skip, | ||
| SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, | ||
| num_warps=num_warps, | ||
| num_stages=1, | ||
| ) | ||
|
|
||
| return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None | ||
| return ( | ||
| dq, | ||
| dk, | ||
| dv, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| ) | ||
|
|
||
|
|
||
| def attention( | ||
|
|
@@ -901,8 +1005,9 @@ def attention( | |
| sparsity_m: int = 4, | ||
| num_sink_tokens: int = 0, | ||
| dense_window_size: int = 64, | ||
| skip_softmax_threshold: float | None = None, | ||
| ) -> torch.Tensor: | ||
| """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax. | ||
| """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax. | ||
|
|
||
| Args: | ||
| q: [total_q_tokens, num_q_heads, head_dim] | ||
|
|
@@ -926,6 +1031,12 @@ def attention( | |
| dense_window_size: Tokens near the query diagonal kept dense (local | ||
| attention window). Absolute token count, BLOCK_N-independent. | ||
| Default 64 (one reference block). | ||
| skip_softmax_threshold: BLASST threshold lambda | ||
| (https://arxiv.org/pdf/2512.12087). Skip KV tiles where | ||
| ``exp(tile_max - running_max) < lambda``, meaning the tile's | ||
| softmax contribution is negligible. Tiles are skipped entirely | ||
| (no softmax, V load, or BMM2). The threshold is applied on | ||
| unscaled scores. Set to ``None`` or ``0`` to disable. | ||
|
|
||
| Returns: | ||
| Output tensor [total_q_tokens, num_q_heads, head_dim]. | ||
|
|
@@ -947,6 +1058,7 @@ def attention( | |
| sparsity_m, | ||
| num_sink_tokens, | ||
| dense_window_size, | ||
| skip_softmax_threshold, | ||
| ) | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems there is one path for APPLY_SKIP_SOFTMAX=True and one for False. The standard path and the "not all_skip" branch inside the skip path are nearly identical. Since APPLY_SKIP_SOFTMAX is a tl.constexpr, Triton will compile-time eliminate the dead branch, but the source code duplication makes maintenance harder.
Consider structuring as:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good observation. The duplication is intentional. I considered this refactoring, but Triton doesn't support reassigning a variable from a
`tl.constexpr` branch and using it as a runtime condition. The reassignment is silently ignored. If we remove `tl.constexpr` to work around this, the runtime if/else adds branch divergence overhead on every attention tile iteration.