diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c647cf716..cc172bdcf 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ NVIDIA Model Optimizer Changelog - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. - Added iterator interface using CalibrationDataReader in ONNX quantization workflow. - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. +- Add skip-softmax (BLASST) tile skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. 
**Bug Fixes** diff --git a/modelopt/torch/kernels/hf_triton_attention.py b/modelopt/torch/kernels/hf_triton_attention.py index afe4852ea..5021d34e3 100644 --- a/modelopt/torch/kernels/hf_triton_attention.py +++ b/modelopt/torch/kernels/hf_triton_attention.py @@ -105,16 +105,20 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k - # N:M sparse softmax — prefill only (decode should not sparsify KV) - if not is_decode and getattr(module, "_apply_sparse_nm", False): - # _sparse_method_instance is set by SparseAttentionModule._init_sparse_method() - # in modelopt/torch/sparsity/attention_sparsity/sparse_attention.py - method = getattr(module, "_sparse_method_instance", None) - if method is not None: - kw["sparsity_n"] = getattr(method, "sparsity_n", 2) - kw["sparsity_m"] = getattr(method, "sparsity_m", 4) - kw["num_sink_tokens"] = getattr(method, "num_sink_tokens", 0) - kw["dense_window_size"] = getattr(method, "dense_window_size", 64) + # Sparse attention params + method = getattr(module, "_sparse_method_instance", None) + + # N:M sparse softmax: prefill only (no perf benefit for decode) + if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False): + kw["sparsity_n"] = method.sparsity_n + kw["sparsity_m"] = method.sparsity_m + kw["num_sink_tokens"] = method.num_sink_tokens + kw["dense_window_size"] = method.dense_window_size + + # Skip-softmax: applies to both prefill and decode + if method is not None and getattr(module, "_apply_skip_softmax", False): + if method.skip_softmax_threshold: + kw["skip_softmax_threshold"] = method.skip_softmax_threshold o = attention(q, k, v, **kw) diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/triton_fa.py index 69602d8d6..8d3b11f1a 100644 --- a/modelopt/torch/kernels/triton_fa.py +++ b/modelopt/torch/kernels/triton_fa.py @@ -23,6 +23,8 @@ metadata (b_start_loc, b_seq_len). 
Supports causal masking and autograd. """ +import math + import torch import triton import triton.language as tl @@ -248,6 +250,8 @@ def _attn_fwd( SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8) NUM_SINK_TOKENS: tl.constexpr = 0, # KV positions before this are kept dense (attention sinks) DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) + APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores + SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -320,26 +324,65 @@ def _attn_fwd( scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M ) - # --- Online softmax update --- - # 1. Update running max - m_new = tl.maximum(row_max, tl.max(scores, 1)) - # 2. Compute unnormalized attention weights - p = tl.math.exp2(scores - m_new[:, None]) - l_new = tl.sum(p, 1) - # 3. Correction factor: rescale previous tiles when max changes - correction = tl.math.exp2(row_max - m_new) - row_sum = row_sum * correction + l_new - acc = acc * correction[:, None] - - # Load V [BLOCK_N, BLOCK_D] and accumulate: acc += attn_weights @ V - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) - acc = tl.dot(p.to(v.dtype), v, acc) - row_max = m_new + if APPLY_SKIP_SOFTMAX: + # --- Skip-softmax (BLASST, https://arxiv.org/pdf/2512.12087) --- + # + # Algorithm: During FlashAttention's block-wise computation, we + # maintain a running maximum m_i^(j) across blocks. 
If a block's + # local maximum ~m_i^(j) is significantly smaller than the running + # maximum m_i^(j): + # + # ~m_i^(j) - m_i^(j) < ln(lambda) + # + # then exp(~m_i^(j) - m_i^(j)) < lambda ≈ 0, meaning the block's + # contribution to the final output is negligible. We skip the + # softmax computation, V load, and BMM2 computation entirely. + # + # The threshold is pre-scaled by qk_scale in the Python wrapper so + # it can be compared directly against scaled scores (matching the + # BLASST reference semantics on unscaled scores). + tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) + # Per-row: True if row's tile max is negligible vs running max + can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) + # Per-tile: skip entire tile only if ALL rows are negligible + skip_tile = tl.min(can_skip.to(tl.int32)) == 1 + + if not skip_tile: + m_new = tl.maximum(row_max, tile_row_max) + p = tl.math.exp2(scores - m_new[:, None]) + l_new = tl.sum(p, 1) + correction = tl.math.exp2(row_max - m_new) + row_sum = row_sum * correction + l_new + acc = acc * correction[:, None] + + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + acc = tl.dot(p.to(v.dtype), v, acc) + row_max = m_new + # else: tile skipped: no softmax computation, V load, and BMM2 computation + else: + # --- Standard path: no skip check --- + # Online softmax update + m_new = tl.maximum(row_max, tl.max(scores, 1)) + p = tl.math.exp2(scores - m_new[:, None]) + l_new = tl.sum(p, 1) + correction = tl.math.exp2(row_max - m_new) + row_sum = row_sum * correction + l_new + acc = acc * correction[:, None] + + # Load V and accumulate + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + acc = 
tl.dot(p.to(v.dtype), v, acc) + row_max = m_new # --- Final normalization: output = acc / row_sum --- acc = acc / row_sum[:, None] @@ -440,6 +483,8 @@ def _attn_bwd_dq( SPARSITY_M: tl.constexpr = 4, NUM_SINK_TOKENS: tl.constexpr = 0, DENSE_WINDOW_SIZE: tl.constexpr = 64, + APPLY_SKIP_SOFTMAX: tl.constexpr = False, + SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, ): """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles. @@ -523,6 +568,16 @@ def _attn_bwd_dq( p = tl.math.exp2(scores - lse[:, None]) + # Skip-softmax backward: zero out P for rows with negligible contribution. + # Per-row using final LSE because forward/backward tile sizes may differ + # (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile + # skip masks from forward wouldn't align. LSE >= any intermediate running + # max, so this conservatively zeros out at least what forward skipped. + if APPLY_SKIP_SOFTMAX: + tile_row_max = tl.max(scores, 1) + can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2) + p = tl.where(can_skip[:, None], 0.0, p) + # dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K dp = tl.dot(do, tl.trans(v)) ds = p * (dp - row_delta[:, None]) @@ -574,6 +629,8 @@ def _attn_bwd_dkdv( SPARSITY_M: tl.constexpr = 4, NUM_SINK_TOKENS: tl.constexpr = 0, DENSE_WINDOW_SIZE: tl.constexpr = 64, + APPLY_SKIP_SOFTMAX: tl.constexpr = False, + SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, ): """Phase 2 of backward: compute dK, dV for one KV tile. @@ -665,6 +722,16 @@ def _attn_bwd_dkdv( p = tl.math.exp2(scores - lse[:, None]) + # Skip-softmax backward: zero out P for rows with negligible contribution. + # Per-row using final LSE because forward/backward tile sizes may differ + # (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile + # skip masks from forward wouldn't align. LSE >= any intermediate running + # max, so this conservatively zeros out at least what forward skipped. 
+ if APPLY_SKIP_SOFTMAX: + tile_row_max = tl.max(scores, 1) + can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2) + p = tl.where(can_skip[:, None], 0.0, p) + # dV += P^T @ dO dv += tl.dot(tl.trans(p.to(do_tile.dtype)), do_tile) # dS = P * (dO @ V^T - delta), dK += dS^T @ Q @@ -700,6 +767,7 @@ def forward( sparsity_m, num_sink_tokens, dense_window_size, + skip_softmax_threshold, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -720,6 +788,17 @@ def forward( # Triton tiles must be powers of 2; pad head dim BLOCK_D = triton.next_power_of_2(HEAD_DIM) + # Skip-softmax: convert threshold to scaled log2 space for the kernel. + # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks + # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space + # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we + # pre-scale: threshold_scaled = log2(lambda) * sm_scale. + apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 + if apply_skip: + skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale + else: + skip_threshold_log2 = 0.0 + o = torch.empty_like(q) lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32) @@ -758,6 +837,8 @@ def grid(META): SPARSITY_M=sparsity_m, NUM_SINK_TOKENS=num_sink_tokens, DENSE_WINDOW_SIZE=dense_window_size, + APPLY_SKIP_SOFTMAX=apply_skip, + SKIP_THRESHOLD_LOG2=skip_threshold_log2, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -776,6 +857,8 @@ def grid(META): ctx.sparsity_m = sparsity_m ctx.num_sink_tokens = num_sink_tokens ctx.dense_window_size = dense_window_size + ctx.apply_skip = apply_skip + ctx.skip_threshold_log2 = skip_threshold_log2 return o @staticmethod @@ -854,6 +937,8 @@ def backward(ctx, grad_output): SPARSITY_M=ctx.sparsity_m, NUM_SINK_TOKENS=ctx.num_sink_tokens, DENSE_WINDOW_SIZE=ctx.dense_window_size, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, num_warps=num_warps, num_stages=1, ) 
@@ -877,11 +962,30 @@ def backward(ctx, grad_output): SPARSITY_M=ctx.sparsity_m, NUM_SINK_TOKENS=ctx.num_sink_tokens, DENSE_WINDOW_SIZE=ctx.dense_window_size, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, num_warps=num_warps, num_stages=1, ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + return ( + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) def attention( @@ -901,8 +1005,9 @@ def attention( sparsity_m: int = 4, num_sink_tokens: int = 0, dense_window_size: int = 64, + skip_softmax_threshold: float | None = None, ) -> torch.Tensor: - """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax. + """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -926,6 +1031,12 @@ def attention( dense_window_size: Tokens near the query diagonal kept dense (local attention window). Absolute token count, BLOCK_N-independent. Default 64 (one reference block). + skip_softmax_threshold: BLASST threshold lambda + (https://arxiv.org/pdf/2512.12087). Skip KV tiles where + ``exp(tile_max - running_max) < lambda``, meaning the tile's + softmax contribution is negligible. Tiles are skipped entirely + (no softmax, V load, or BMM2). The threshold is applied on + unscaled scores. Set to ``None`` or ``0`` to disable. Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. 
@@ -947,6 +1058,7 @@ def attention( sparsity_m, num_sink_tokens, dense_window_size, + skip_softmax_threshold, ) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 7f9e3d76e..ae20323ad 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -129,6 +129,16 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + skip_softmax_threshold: float = ModeloptField( + default=0.1, + title="Skip-softmax threshold.", + description=( + "Tiles contributing less than this fraction are skipped entirely. " + "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. " + "Set to 0 to disable." + ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): @@ -528,9 +538,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +# Default skip-softmax configuration for Triton kernel +SKIP_SOFTMAX_TRITON_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "triton_skip_softmax", + "skip_softmax_threshold": 0.1, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "SKIP_SOFTMAX_TRITON_DEFAULT", "SPARSE_SOFTMAX_DEFAULT", "CalibrationConfig", "FlashSkipSoftmaxConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 1bd9a547d..7e40ec648 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_skip_softmax, triton_sparse_softmax +from . 
import flash_skip_softmax, triton_skip_softmax, triton_sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index c6a8638b7..8254cd8ff 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -76,10 +76,7 @@ def apply_sparsity( Returns: Masked attention scores with sparse elements set to -inf """ - raise NotImplementedError( - f"{type(self).__name__} does not implement apply_sparsity. " - "Sparsity may be fused into the kernel (Triton backend)." - ) + raise NotImplementedError(f"{type(self).__name__} does not implement apply_sparsity.") def get_sparse_context(self, module: torch.nn.Module): """Return a context manager that activates this method's sparsity during forward. diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py new file mode 100644 index 000000000..4db51e894 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Skip-softmax method for attention via Triton kernel tile skipping.""" + +from contextlib import contextmanager + +from .registry import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("triton_skip_softmax") +class TritonSkipSoftmaxMethod(SparseAttentionMethod): + """Skip-softmax tile skipping via the Triton flash attention kernel. + + During prefill, KV tiles whose max attention score is far below the + running softmax max are skipped entirely — no V load, no softmax + update, no accumulation. This is a long-context optimization that + benefits sequences with strong attention locality. + + Config params: + skip_softmax_threshold: Tiles contributing less than this fraction + are skipped. Typical values: 1e-3 to 1e-1. Set to 0 to disable. + """ + + def __init__(self, method_config=None): + """Initialize with skip-softmax threshold from config.""" + super().__init__() + method_config = method_config or {} + self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) + + @property + def name(self) -> str: + """Method name identifier.""" + return "triton_skip_softmax" + + def get_sparse_context(self, module): + """Return context manager that activates skip-softmax during forward.""" + + @contextmanager + def _skip_softmax_context(): + module._apply_skip_softmax = True + try: + yield + finally: + module._apply_skip_softmax = False + + return _skip_softmax_context() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py index 5ee2696da..a5174496c 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py @@ -112,20 +112,6 @@ def test_decode_matches_sdpa(self): ref = F.scaled_dot_product_attention(qb, kb, vb, is_causal=False).squeeze(2) torch.testing.assert_close(out[i : i + 1], ref, rtol=1e-3, atol=1e-3) - def test_sparse_disabled_matches_dense(self): - 
"""sparsity_n=0 produces bit-identical output to default (dense).""" - seq_lens = [128, 128] - total = sum(seq_lens) - scale = 1.0 / (64**0.5) - - torch.manual_seed(99) - q, k, v = make_qkv(total, 4, 2, 64) - locs, lens = make_varlen_meta(seq_lens) - - out_dense = attention(q, k, v, locs, lens, 128, softmax_scale=scale) - out_n0 = attention(q, k, v, locs, lens, 128, softmax_scale=scale, sparsity_n=0) - assert torch.equal(out_dense, out_n0) - # --------------------------------------------------------------------------- # Backward correctness (dense) @@ -332,62 +318,3 @@ def test_triton_padded_batch(self, tiny_llama_dir): with torch.no_grad(): logits = model(**inputs).logits assert not torch.isnan(logits).any() and not torch.isinf(logits).any() - - def test_sparse_nm_via_sparsify(self, tiny_llama_dir): - """mtsa.sparsify() with N:M sparse softmax produces finite logits that differ from dense.""" - pytest.importorskip("transformers") - from transformers import AutoModelForCausalLM, AutoTokenizer - - import modelopt.torch.sparsity.attention_sparsity as mtsa - - tok = AutoTokenizer.from_pretrained(tiny_llama_dir) - if tok.pad_token_id is None: - tok.pad_token_id = tok.eos_token_id - # Use a long input (fill max_position_embeddings=64) so sparsity has tiles to prune - ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda") - - # Dense baseline (triton backend, no sparsity) - model_dense = AutoModelForCausalLM.from_pretrained( - tiny_llama_dir, - attn_implementation="modelopt_triton", - torch_dtype=torch.bfloat16, - device_map="cuda", - ) - model_dense.eval() - with torch.no_grad(): - logits_dense = model_dense(input_ids=ids).logits - del model_dense - - # Sparse via mtsa.sparsify() with dense_window_size=0 to force sparsity on all tiles - sparse_cfg = { - "sparse_cfg": { - "*attn*": { - "method": "triton_sparse_softmax", - "sparsity_n": 2, - "sparsity_m": 4, - "num_sink_tokens": 0, - "dense_window_size": 0, - "backend": "triton", - "enable": True, - }, - 
"default": {"enable": False}, - }, - } - model_sparse = AutoModelForCausalLM.from_pretrained( - tiny_llama_dir, - torch_dtype=torch.bfloat16, - device_map="cuda", - ) - mtsa.sparsify(model_sparse, sparse_cfg) - assert model_sparse.config._attn_implementation == "modelopt_triton" - model_sparse.eval() - with torch.no_grad(): - logits_sparse = model_sparse(input_ids=ids).logits - - # Sparse output should be finite - assert not torch.isnan(logits_sparse).any(), "NaN in sparse logits" - assert not torch.isinf(logits_sparse).any(), "Inf in sparse logits" - # Sparse output should differ from dense (sparsity changes attention) - assert not torch.allclose(logits_sparse, logits_dense, atol=1e-2), ( - "Sparse logits identical to dense — sparsity may not be applied" - ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py new file mode 100644 index 000000000..21b2a12ca --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPU tests for skip-softmax (BLASST) on the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.kernels import attention, register_triton_attention + + if register_triton_attention is not None: + register_triton_attention() + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSkipSoftmax: + """Skip-softmax tile-skipping approximation tests.""" + + def _make_inputs(self, batch=2, seq_len=256, num_heads=4, num_kv_heads=2, head_dim=64): + total = batch * seq_len + torch.manual_seed(77) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + locs, lens = make_varlen_meta([seq_len] * batch) + return q, k, v, locs, lens + + def test_disabled_matches_dense(self): + """skip_softmax_threshold=None/0.0 produces bit-identical output to dense.""" + q, k, v, locs, lens = self._make_inputs() + scale = 1.0 / (64**0.5) + out_none = attention(q, k, v, locs, lens, 256, softmax_scale=scale) + out_zero = attention( + q, k, v, locs, lens, 256, softmax_scale=scale, skip_softmax_threshold=0.0 + ) + assert torch.equal(out_none, out_zero) + + def test_small_threshold_close_to_dense(self): + """A small threshold (1e-3) should produce output very close to dense.""" + q, k, v, locs, lens = self._make_inputs() + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale) + out_skip = attention( + q, k, v, locs, lens, 256, softmax_scale=scale, 
skip_softmax_threshold=1e-3 + ) + torch.testing.assert_close(out_skip, out_dense, rtol=5e-2, atol=5e-2) + + def test_large_threshold_differs_from_dense(self): + """A large threshold should produce noticeably different output on spiky data. + + Random data distributes attention uniformly so few tiles are skipped. + Use long sequences with spiky attention (one hot-key per query) to + ensure the BLASST algorithm actually skips negligible tiles. + """ + batch, seq_len, num_heads, head_dim = 1, 4096, 4, 64 + total = batch * seq_len + torch.manual_seed(77) + # Create spiky attention: each query attends strongly to one key + q = torch.zeros(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + # Set each query to match a single key so attention is concentrated + for i in range(total): + q[i] = k[i] + locs = torch.zeros(batch, device="cuda", dtype=torch.int32) + lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) + + scale = 1.0 / (head_dim**0.5) + out_dense = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + out_skip = attention( + q, k, v, locs, lens, seq_len, softmax_scale=scale, skip_softmax_threshold=0.5 + ) + assert not torch.allclose(out_skip, out_dense, atol=1e-3) + + def test_output_shape_unchanged(self): + """Skip-softmax does not change output shape.""" + q, k, v, locs, lens = self._make_inputs() + scale = 1.0 / (64**0.5) + out = attention(q, k, v, locs, lens, 256, softmax_scale=scale, skip_softmax_threshold=1e-2) + assert out.shape == q.shape + + def test_monotonic_approximation_error(self): + """Larger threshold -> larger error vs dense (monotonic degradation).""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + errors = [] + for threshold in [1e-4, 
1e-2, 1e-1]: + out_skip = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, skip_softmax_threshold=threshold + ) + errors.append((out_skip - out_dense).abs().mean().item()) + assert errors[0] <= errors[1] <= errors[2], f"Errors not monotonic: {errors}" + + def test_decode_single_token(self): + """Skip-softmax works for decode (single Q token per sequence).""" + batch = 2 + seq_lens_k = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(42) + q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16) + total_kv = sum(seq_lens_k) + k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + b_start_loc_q = torch.arange(batch, device="cuda", dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device="cuda", dtype=torch.int32) + cumsum = [0] + for sl in seq_lens_k: + cumsum.append(cumsum[-1] + sl) + b_start_loc_k = torch.tensor(cumsum[:-1], device="cuda", dtype=torch.int32) + b_seq_len_k = torch.tensor(seq_lens_k, device="cuda", dtype=torch.int32) + + out_dense = attention( + q_flat, + k_flat, + v_flat, + b_start_loc_q, + b_seq_len_q, + 1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + ) + out_skip = attention( + q_flat, + k_flat, + v_flat, + b_start_loc_q, + b_seq_len_q, + 1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + skip_softmax_threshold=1e-3, + ) + torch.testing.assert_close(out_skip, out_dense, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSkipSoftmaxVsPytorchRef: + """Cross-validate Triton skip-softmax against PyTorch flash_skip_softmax reference.""" + + def 
test_triton_matches_pytorch_reference(self): + """Triton skip-softmax output should be close to PyTorch reference with same threshold. + + The reference computes block-level BLASST masks using FlashSkipSoftmax.calculate_sparsity + and applies them to standard softmax attention. The Triton kernel fuses the same skip + logic into the online softmax inner loop. + """ + from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import ( + FlashSkipSoftmax, + ) + + batch, seq_len = 1, 256 + num_heads, num_kv_heads, head_dim = 4, 4, 64 # MHA for simplicity + scale = 1.0 / (head_dim**0.5) + threshold = 1e-2 + + torch.manual_seed(123) + q_4d = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) + k_4d = torch.randn( + batch, num_kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float32 + ) + v_4d = torch.randn( + batch, num_kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float32 + ) + + # --- PyTorch reference: eager attention with flash_skip_softmax --- + scores = torch.matmul(q_4d, k_4d.transpose(-2, -1)) * scale + # Apply causal mask + causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="cuda"), diagonal=1).bool() + scores = scores.masked_fill(causal_mask[None, None, :, :], float("-inf")) + + # Apply BLASST mask via flash_skip_softmax + method = FlashSkipSoftmax( + method_config={ + "thresholds": {"prefill": [threshold]}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + sparse_mask, _ = method.calculate_sparsity(scores) + if sparse_mask is not None: + scores = scores.masked_fill(~sparse_mask, float("-inf")) + p = torch.softmax(scores, dim=-1) + ref_out = torch.matmul(p, v_4d) # [batch, heads, seq, dim] + + # --- Triton kernel with skip-softmax --- + total = batch * seq_len + q_flat = q_4d.permute(0, 2, 1, 3).reshape(total, num_heads, head_dim).contiguous() + k_flat = k_4d.permute(0, 2, 1, 3).reshape(total, num_kv_heads, head_dim).contiguous() + v_flat = 
v_4d.permute(0, 2, 1, 3).reshape(total, num_kv_heads, head_dim).contiguous() + locs = torch.arange(batch, device="cuda", dtype=torch.int32) * seq_len + lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) + + triton_out = attention( + q_flat, + k_flat, + v_flat, + locs, + lens, + seq_len, + is_causal=True, + softmax_scale=scale, + skip_softmax_threshold=threshold, + ) + triton_out_4d = triton_out.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3) + + # Both outputs should be close — same algorithm, different implementations. + # Observed max abs error ~2e-3 (online vs standard softmax precision diffs). + torch.testing.assert_close(triton_out_4d, ref_out, rtol=5e-3, atol=5e-3) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSkipSoftmaxHFIntegration: + """HF integration for skip-softmax via mtsa.sparsify().""" + + def test_skip_softmax_via_sparsify(self, tiny_llama_dir): + """mtsa.sparsify() with triton_skip_softmax produces finite logits.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer + + import modelopt.torch.sparsity.attention_sparsity as mtsa + + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda") + + # Dense baseline (triton backend, no skip) + model_dense = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + model_dense.eval() + with torch.no_grad(): + logits_dense = model_dense(input_ids=ids).logits + del model_dense + + # Skip-softmax via mtsa.sparsify() + model_skip = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + mtsa.sparsify(model_skip, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT) + model_skip.eval() + with torch.no_grad(): + logits_skip = 
model_skip(input_ids=ids).logits + + assert not torch.isnan(logits_skip).any(), "NaN in skip-softmax logits" + assert not torch.isinf(logits_skip).any(), "Inf in skip-softmax logits" + # On short sequences (64 tokens), no tiles are skipped — output should match dense + torch.testing.assert_close(logits_skip, logits_dense, rtol=1e-3, atol=1e-3) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py index 7fd961a41..4eec5799a 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py @@ -346,3 +346,83 @@ def test_sparse_gradients_differ_from_dense(self): assert not torch.allclose(v_d.grad, v_s.grad, atol=1e-3), ( "dV same with and without sparsity" ) + + +# --------------------------------------------------------------------------- +# Integration +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseNMIntegration: + """N:M sparse softmax integration tests.""" + + def test_sparse_disabled_matches_dense(self): + """sparsity_n=0 produces bit-identical output to default (dense).""" + seq_lens = [128, 128] + total = sum(seq_lens) + scale = 1.0 / (64**0.5) + + torch.manual_seed(99) + q, k, v = make_qkv(total, 4, 2, 64) + locs, lens = make_varlen_meta(seq_lens) + + out_dense = attention(q, k, v, locs, lens, 128, softmax_scale=scale) + out_n0 = attention(q, k, v, locs, lens, 128, softmax_scale=scale, sparsity_n=0) + assert torch.equal(out_dense, out_n0) + + def test_sparse_nm_via_sparsify(self, tiny_llama_dir): + """mtsa.sparsify() with N:M sparse softmax produces finite logits that differ from dense.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer + + import 
modelopt.torch.sparsity.attention_sparsity as mtsa + + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda") + + # Dense baseline (triton backend, no sparsity) + model_dense = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + model_dense.eval() + with torch.no_grad(): + logits_dense = model_dense(input_ids=ids).logits + del model_dense + + # Sparse via mtsa.sparsify() with dense_window_size=0 to force sparsity on all tiles + sparse_cfg = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 0, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, + } + model_sparse = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + mtsa.sparsify(model_sparse, sparse_cfg) + assert model_sparse.config._attn_implementation == "modelopt_triton" + model_sparse.eval() + with torch.no_grad(): + logits_sparse = model_sparse(input_ids=ids).logits + + assert not torch.isnan(logits_sparse).any(), "NaN in sparse logits" + assert not torch.isinf(logits_sparse).any(), "Inf in sparse logits" + assert not torch.allclose(logits_sparse, logits_dense, atol=1e-2), ( + "Sparse logits identical to dense — sparsity may not be applied" + )