diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3993f4670..c647cf716 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,7 @@ NVIDIA Model Optimizer Changelog - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. - Added iterator interface using CalibrationDataReader in ONNX quantization workflow. +- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. **Bug Fixes** diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md index a55abaa6a..41519c466 100644 --- a/examples/llm_sparsity/attention_sparsity/README.md +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -1,6 +1,11 @@ # Attention Sparsity for HuggingFace Models -In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. Two attention backends are supported: +In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. Two sparsity methods are supported: + +- **Skip-softmax** (`flash_skip_softmax`): Skips attention tiles whose contribution falls below a threshold, following the [BLASST](https://arxiv.org/pdf/2512.12087) algorithm.
+- **N:M sparse softmax** (`triton_sparse_softmax`): For every M consecutive key positions, keeps the top-N attention scores and sets the rest to -inf before softmax. + +Two attention backends are available: - **pytorch** (default): Patches `F.softmax` to apply skip-softmax sparsity (requires `attn_implementation="eager"`) - **triton**: Uses a fused Triton Flash Attention kernel with in-kernel sparsity (uses `attn_implementation="modelopt_triton"`) @@ -29,9 +34,9 @@ model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) ## Configuration Options -Two pre-defined configurations are available: +### Skip-Softmax -### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT) +#### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT) Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths. @@ -41,7 +46,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAU model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) ``` -### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB) +#### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB) Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use. @@ -51,6 +56,46 @@ from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB) ``` +### N:M Sparse Softmax (SPARSE_SOFTMAX_DEFAULT) + +Applies N:M structured sparsity to attention scores using the Triton backend. For every M consecutive key positions, keeps only the top-N scores and sets the rest to -inf. Supports M=4 (N=1,2,3) and M=8 (N=1..7). Attention sinks and a local dense window can be configured to preserve important positions. 
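The selection rule above can be sketched with plain `torch.topk` (an illustrative stand-in only; the fused kernel selects within its tiles and breaks ties by position rather than by `topk`):

```python
import torch

# 8 pre-softmax scores for one query row, grouped as two windows of M=4 keys
scores = torch.tensor([0.9, -0.3, 1.2, 0.1, 0.4, 0.8, -1.0, 0.5])
m, n = 4, 2
groups = scores.view(-1, m)
# Mark the top-n entries of each group, mask the rest to -inf
keep = torch.zeros_like(groups, dtype=torch.bool)
keep.scatter_(1, torch.topk(groups, n, dim=1).indices, True)
sparse = groups.masked_fill(~keep, float("-inf")).view(-1)
# keeps 0.9 and 1.2 in the first group, 0.8 and 0.5 in the second; the rest become -inf
```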
+ +```python +from modelopt.torch.sparsity.attention_sparsity.config import SPARSE_SOFTMAX_DEFAULT + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.1-8B", + torch_dtype=torch.bfloat16, + device_map="cuda", +) + +model = mtsa.sparsify(model, config=SPARSE_SOFTMAX_DEFAULT) +``` + +Custom N:M configuration: + +```python +sparse_cfg = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, # Keep top-2 of every 4 + "sparsity_m": 4, # Group size + "num_sink_tokens": 4, # Keep first 4 tokens dense (attention sinks) + "dense_window_size": 128, # Keep tokens within distance 128 dense + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + +model = mtsa.sparsify(model, config=sparse_cfg) +``` + +> [!Note] +> N:M sparse softmax requires the Triton backend (`backend="triton"`). The `attn_implementation` is automatically set to `"modelopt_triton"` by `mtsa.sparsify()`. N:M sparsity is applied during prefill only — decode tokens are not sparsified. 
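For a quick sanity check of these settings outside the kernel, the prefill-time effect can be emulated on a full score matrix. `nm_sparse_softmax_ref` below is a hypothetical reference, not the fused implementation: it uses `torch.topk` instead of the kernel's positional tie-breaking, applies the dense window per token rather than per tile, and omits causal masking for brevity:

```python
import torch


def nm_sparse_softmax_ref(scores, n=2, m=4, num_sink_tokens=0, dense_window_size=0):
    """Emulate N:M sparse softmax on [seq_q, seq_k] pre-softmax scores."""
    seq_q, seq_k = scores.shape
    groups = scores.view(seq_q, seq_k // m, m)
    # Keep the top-n scores of every group of m consecutive key positions
    keep = torch.zeros_like(groups, dtype=torch.bool)
    keep.scatter_(-1, torch.topk(groups, n, dim=-1).indices, True)
    keep = keep.view(seq_q, seq_k)
    # Sink tokens and a causal local window near the diagonal stay dense
    q_pos = torch.arange(seq_q).unsqueeze(1)
    k_pos = torch.arange(seq_k).unsqueeze(0)
    dist = q_pos - k_pos
    keep |= (k_pos < num_sink_tokens) | ((dist >= 0) & (dist < dense_window_size))
    return torch.softmax(scores.masked_fill(~keep, float("-inf")), dim=-1)


probs = nm_sparse_softmax_ref(torch.randn(16, 16), n=2, m=4, num_sink_tokens=4)
# Sink columns stay fully dense; all other groups keep only their top-2 scores
```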
+ ## Prerequisites ### Local Installation @@ -104,8 +149,8 @@ The calibration process: | Argument | Default | Description | |----------|---------|-------------| | `--pyt_ckpt_path` | Required | HuggingFace model path or name | -| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` | -| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) | +| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax`, `skip_softmax_calib`, or `sparse_softmax` | +| `--backend` | `None` | Backend override: `pytorch` (skip-softmax) or `triton` (N:M sparse softmax); defaults to the config's backend | | `--seq_len` | `2048` | Maximum sequence length for input prompts | | `--export_dir` | `None` | Directory to export the sparsified model | @@ -166,3 +211,4 @@ model = mtsa.sparsify(model, config=custom_config) - [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/) - [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER) +- [BLASST: Block-Level Adaptive Structured Sparse Training](https://arxiv.org/pdf/2512.12087) — skip-softmax algorithm diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 8115d4aaf..bdf2bb5ee 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import ( SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT, + SPARSE_SOFTMAX_DEFAULT, ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor @@ -43,6 +44,7 @@ SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, "skip_softmax_calib": SKIP_SOFTMAX_CALIB, + "sparse_softmax": SPARSE_SOFTMAX_DEFAULT, } @@ -168,9 +170,10 @@ def main(args): # Apply CLI overrides to sparse_cfg sparse_cfg = sparse_config.get("sparse_cfg", {}) - for layer_cfg in sparse_cfg.values(): - if isinstance(layer_cfg, dict) and "method" in
layer_cfg: - layer_cfg["backend"] = args.backend + if args.backend is not None: + for layer_cfg in sparse_cfg.values(): + if isinstance(layer_cfg, dict) and "method" in layer_cfg: + layer_cfg["backend"] = args.backend if args.target_sparse_ratio is not None: calib = sparse_cfg.setdefault("calibration", {}) assert isinstance(calib, dict) @@ -240,9 +243,10 @@ def main(args): parser.add_argument( "--backend", type=str, - default="pytorch", + default=None, choices=["pytorch", "triton"], - help="Backend for sparse attention (default: pytorch). 'triton' uses the fused Triton kernel.", + help="Backend for sparse attention. Overrides the config default if set. " + "'triton' uses the fused Triton kernel.", ) # Sequence length arguments diff --git a/modelopt/torch/kernels/hf_triton_attention.py b/modelopt/torch/kernels/hf_triton_attention.py index 5f10df250..afe4852ea 100644 --- a/modelopt/torch/kernels/hf_triton_attention.py +++ b/modelopt/torch/kernels/hf_triton_attention.py @@ -105,6 +105,17 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k + # N:M sparse softmax — prefill only (decode should not sparsify KV) + if not is_decode and getattr(module, "_apply_sparse_nm", False): + # _sparse_method_instance is set by SparseAttentionModule._init_sparse_method() + # in modelopt/torch/sparsity/attention_sparsity/sparse_attention.py + method = getattr(module, "_sparse_method_instance", None) + if method is not None: + kw["sparsity_n"] = getattr(method, "sparsity_n", 2) + kw["sparsity_m"] = getattr(method, "sparsity_m", 4) + kw["num_sink_tokens"] = getattr(method, "num_sink_tokens", 0) + kw["dense_window_size"] = getattr(method, "dense_window_size", 64) + o = attention(q, k, v, **kw) attn_output = o.view(batch, seq_len, num_heads, head_dim) diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/triton_fa.py index b9184788b..69602d8d6 100644 --- 
a/modelopt/torch/kernels/triton_fa.py +++ b/modelopt/torch/kernels/triton_fa.py @@ -45,6 +45,145 @@ _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] +# --------------------------------------------------------------------------- +# N:M sparse softmax helpers +# --------------------------------------------------------------------------- +@triton.jit +def _sparse_nm_masks_m4(x0, x1, x2, x3, N: tl.constexpr): + """Top-N of 4 selection via pure boolean logic (6 comparisons, no int casts). + + Uses ``>=`` so that ties are broken by index (lower index wins). + Guarantees exactly N masks are True for any input including all-equal. + + Boolean formulas for "at least K of 3 wins": + K=3 (N=1): AND of all — must beat all 3 others + K=2 (N=2): majority — must beat at least 2 (sorting network) + K=1 (N=3): OR of all — must beat at least 1 + """ + c01 = x0 >= x1 + c02 = x0 >= x2 + c03 = x0 >= x3 + c12 = x1 >= x2 + c13 = x1 >= x3 + c23 = x2 >= x3 + + nc01 = ~c01 + nc02 = ~c02 + nc03 = ~c03 + nc12 = ~c12 + nc13 = ~c13 + nc23 = ~c23 + + if N == 1: + # Keep max only: must beat all 3 + m0 = c01 & c02 & c03 + m1 = nc01 & c12 & c13 + m2 = nc02 & nc12 & c23 + m3 = nc03 & nc13 & nc23 + elif N == 2: + # Majority vote: must beat at least 2 of 3 + m0 = (c01 & c02) | (c01 & c03) | (c02 & c03) + m1 = (nc01 & c12) | (nc01 & c13) | (c12 & c13) + m2 = (nc02 & nc12) | (nc02 & c23) | (nc12 & c23) + m3 = (nc03 & nc13) | (nc03 & nc23) | (nc13 & nc23) + elif N == 3: + # Keep all but min: must beat at least 1 + m0 = c01 | c02 | c03 + m1 = nc01 | c12 | c13 + m2 = nc02 | nc12 | c23 + m3 = nc03 | nc13 | nc23 + else: + tl.static_assert(False, "N must be 1, 2, or 3 for M=4") + + return m0, m1, m2, m3 + + +@triton.jit +def _apply_sparse_nm_to_qk_tile( + qk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, + SPARSITY_M: tl.constexpr, +): + """Apply N:M sparse softmax to a QK score tile. 
+ + For every ``SPARSITY_M`` consecutive elements along the N (key) dimension, + keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf``. + ``BLOCK_N`` must be divisible by ``SPARSITY_M``. + + For M=4, exactly N values are retained (ties broken by position). + For M=8, a threshold-based approach (``tl.sort``) may retain more + than N values when ties straddle the threshold boundary. + """ + tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8") # noqa: PLR1714 + MASK_VAL: tl.constexpr = float("-inf") + + if SPARSITY_M == 4: + tl.static_assert(BLOCK_N % 4 == 0, "BLOCK_N must be divisible by 4") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 4, 4)) + cols = tl.arange(0, 4)[None, None, :] + x0 = tl.sum(tl.where(cols == 0, reshaped, 0.0), axis=2) + x1 = tl.sum(tl.where(cols == 1, reshaped, 0.0), axis=2) + x2 = tl.sum(tl.where(cols == 2, reshaped, 0.0), axis=2) + x3 = tl.sum(tl.where(cols == 3, reshaped, 0.0), axis=2) + + m0, m1, m2, m3 = _sparse_nm_masks_m4(x0, x1, x2, x3, SPARSITY_N) + + out = tl.full((BLOCK_M, BLOCK_N // 4, 4), 0.0, dtype=qk.dtype) + out = tl.where(cols == 0, tl.expand_dims(tl.where(m0, x0, MASK_VAL), 2), out) + out = tl.where(cols == 1, tl.expand_dims(tl.where(m1, x1, MASK_VAL), 2), out) + out = tl.where(cols == 2, tl.expand_dims(tl.where(m2, x2, MASK_VAL), 2), out) + out = tl.where(cols == 3, tl.expand_dims(tl.where(m3, x3, MASK_VAL), 2), out) + return tl.reshape(out, (BLOCK_M, BLOCK_N)) + + else: # SPARSITY_M == 8 + tl.static_assert(BLOCK_N % 8 == 0, "BLOCK_N must be divisible by 8") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 8, 8)) + + # Sort each group of 8 ascending; N-th largest is at index (8 - N) + sorted_vals = tl.sort(reshaped, dim=2) + KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N # index of N-th largest in ascending order + + # Extract the threshold value at KTH_IDX via masked sum + # Use 0.0 as fill (not -inf) so sum equals just the KTH element + cols = tl.arange(0, 8)[None, None, :] + 
threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2) + + # Mask: keep elements >= threshold (may keep >N on ties — acceptable) + mask = reshaped >= tl.expand_dims(threshold, 2) + return tl.reshape(tl.where(mask, reshaped, MASK_VAL), (BLOCK_M, BLOCK_N)) + + +# --------------------------------------------------------------------------- +# Sink/window dense-region check +# --------------------------------------------------------------------------- +@triton.jit +def _is_dense_region( + kv_start, + tile_q, + seq_len_q, + seq_len_kv, + BLOCK_M: tl.constexpr, + NUM_SINK_TOKENS: tl.constexpr, + DENSE_WINDOW_SIZE: tl.constexpr, +): + """Check if a KV tile falls in a dense region (sink tokens or local window). + + Uses absolute token positions so the result is BLOCK_N-independent, + ensuring forward and backward (which may use different BLOCK_N) agree. + + Returns: + True if the tile should be kept dense (skip N:M sparsification). + """ + is_sink = kv_start < NUM_SINK_TOKENS + causal_offset = seq_len_kv - seq_len_q + q_abs_pos = tile_q * BLOCK_M + causal_offset + token_distance = q_abs_pos - kv_start + is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE) + return is_sink or is_local + + # --------------------------------------------------------------------------- # Masking helper # --------------------------------------------------------------------------- @@ -105,6 +244,10 @@ def _attn_fwd( IS_CAUSAL: tl.constexpr, # Whether to apply causal mask HEAD_DIM: tl.constexpr, # Actual head dimension (for d_mask) STORE_LSE: tl.constexpr, # Whether to save LSE for backward pass + SPARSITY_N: tl.constexpr = 0, # N:M sparsity — keep top-N of every M elements (0 = disabled) + SPARSITY_M: tl.constexpr = 4, # N:M sparsity — group size (4 or 8) + NUM_SINK_TOKENS: tl.constexpr = 0, # KV positions before this are kept dense (attention sinks) + DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) ): # 
--- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -162,6 +305,21 @@ def _attn_fwd( scores = tl.dot(q, k) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + # --- Optional N:M sparse softmax --- + if SPARSITY_N > 0: + if not _is_dense_region( + kv_start, + tile_q, + seq_len_q, + seq_len_kv, + BLOCK_M, + NUM_SINK_TOKENS, + DENSE_WINDOW_SIZE, + ): + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + # --- Online softmax update --- # 1. Update running max m_new = tl.maximum(row_max, tl.max(scores, 1)) @@ -278,6 +436,10 @@ def _attn_bwd_dq( BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + SPARSITY_N: tl.constexpr = 0, + SPARSITY_M: tl.constexpr = 4, + NUM_SINK_TOKENS: tl.constexpr = 0, + DENSE_WINDOW_SIZE: tl.constexpr = 64, ): """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles. @@ -343,6 +505,22 @@ def _attn_bwd_dq( # Recompute attention: S = Q @ K^T, P = exp2(S - LSE) scores = tl.dot(q, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + # Re-apply N:M sparse softmax to match forward pass + if SPARSITY_N > 0: + if not _is_dense_region( + kv_start, + tile_q, + seq_len_q, + seq_len_kv, + BLOCK_M, + NUM_SINK_TOKENS, + DENSE_WINDOW_SIZE, + ): + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + p = tl.math.exp2(scores - lse[:, None]) # dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K @@ -392,6 +570,10 @@ def _attn_bwd_dkdv( BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + SPARSITY_N: tl.constexpr = 0, + SPARSITY_M: tl.constexpr = 4, + NUM_SINK_TOKENS: tl.constexpr = 0, + DENSE_WINDOW_SIZE: tl.constexpr = 64, ): """Phase 2 of backward: compute dK, dV for one KV tile. 
@@ -465,6 +647,22 @@ def _attn_bwd_dkdv( # Recompute attention: S = Q @ K^T, P = exp2(S - LSE) scores = tl.dot(q_tile, kT) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + # Re-apply N:M sparse softmax to match forward pass + if SPARSITY_N > 0: + if not _is_dense_region( + kv_start, + qi, + seq_len_q, + seq_len_kv, + BLOCK_M, + NUM_SINK_TOKENS, + DENSE_WINDOW_SIZE, + ): + scores = _apply_sparse_nm_to_qk_tile( + scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M + ) + p = tl.math.exp2(scores - lse[:, None]) # dV += P^T @ dO @@ -498,6 +696,10 @@ def forward( b_start_loc_k, b_seq_len_k, max_input_len_k, + sparsity_n, + sparsity_m, + num_sink_tokens, + dense_window_size, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -552,6 +754,10 @@ def grid(META): IS_CAUSAL=is_causal, HEAD_DIM=HEAD_DIM, STORE_LSE=True, + SPARSITY_N=sparsity_n, + SPARSITY_M=sparsity_m, + NUM_SINK_TOKENS=num_sink_tokens, + DENSE_WINDOW_SIZE=dense_window_size, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -566,6 +772,10 @@ def grid(META): ctx.num_q_heads = num_q_heads ctx.num_kv_heads = num_kv_heads ctx.batch = batch + ctx.sparsity_n = sparsity_n + ctx.sparsity_m = sparsity_m + ctx.num_sink_tokens = num_sink_tokens + ctx.dense_window_size = dense_window_size return o @staticmethod @@ -640,6 +850,10 @@ def backward(ctx, grad_output): BLOCK_N=BLOCK, IS_CAUSAL=ctx.is_causal, HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + NUM_SINK_TOKENS=ctx.num_sink_tokens, + DENSE_WINDOW_SIZE=ctx.dense_window_size, num_warps=num_warps, num_stages=1, ) @@ -659,11 +873,15 @@ def backward(ctx, grad_output): BLOCK_N=BLOCK, IS_CAUSAL=ctx.is_causal, HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + NUM_SINK_TOKENS=ctx.num_sink_tokens, + DENSE_WINDOW_SIZE=ctx.dense_window_size, num_warps=num_warps, num_stages=1, ) - return dq, dk, dv, None, None, None, None, None, None, None, None + return 
dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None def attention( @@ -678,8 +896,13 @@ def attention( b_start_loc_k: torch.Tensor | None = None, b_seq_len_k: torch.Tensor | None = None, max_input_len_k: int | None = None, + *, + sparsity_n: int = 0, + sparsity_m: int = 4, + num_sink_tokens: int = 0, + dense_window_size: int = 64, ) -> torch.Tensor: - """Variable-length flash attention with GQA and autograd support. + """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -693,6 +916,16 @@ def attention( b_start_loc_k: [batch] start offset for K/V (None = same as Q). b_seq_len_k: [batch] length for K/V (None = same as Q). max_input_len_k: Maximum K/V sequence length (None = same as Q). + sparsity_n: N:M sparsity — keep top-N of every M attention scores + along the key dimension. Set to 0 to disable. Examples: + ``sparsity_n=2, sparsity_m=4`` for 2:4 sparsity; + ``sparsity_n=4, sparsity_m=8`` for 4:8 sparsity. + sparsity_m: N:M sparsity — group size (4 or 8). + num_sink_tokens: KV positions before this token index are kept dense + (attention sinks). Absolute token count, BLOCK_N-independent. + dense_window_size: Tokens near the query diagonal kept dense (local + attention window). Absolute token count, BLOCK_N-independent. + Default 64 (one reference block). Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. 
@@ -710,6 +943,10 @@ def attention( b_start_loc_k, b_seq_len_k, max_input_len_k, + sparsity_n, + sparsity_m, + num_sink_tokens, + dense_window_size, ) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 4baf5bbe6..7f9e3d76e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -18,7 +18,7 @@ from collections.abc import Callable from typing import Any -from pydantic import Field, field_validator +from pydantic import Field, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField @@ -96,6 +96,39 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + sparsity_n: int = ModeloptField( + default=2, + title="N in N:M sparsity.", + description=( + "Keep top-N of every M attention scores. Only used by triton_sparse_softmax. " + "Set to 0 to disable sparsity." + ), + ) + + sparsity_m: int = ModeloptField( + default=4, + title="M in N:M sparsity.", + description="Group size for N:M sparsity (4 or 8). Only used by triton_sparse_softmax.", + ) + + num_sink_tokens: int = ModeloptField( + default=0, + title="Number of sink tokens.", + description=( + "KV positions before this index are kept dense (attention sinks). " + "Absolute token count. Only used by triton_sparse_softmax." + ), + ) + + dense_window_size: int = ModeloptField( + default=64, + title="Dense window size in tokens.", + description=( + "Tokens near the query diagonal kept dense (local attention window). " + "Absolute token count. Default 64. Only used by triton_sparse_softmax." 
+ ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): @@ -116,6 +149,38 @@ def validate_backend(cls, v): ) return v + @field_validator("sparsity_m") + @classmethod + def validate_sparsity_m(cls, v): + """Validate sparsity_m is 4 or 8.""" + if v not in (4, 8): + raise ValueError(f"sparsity_m must be 4 or 8, got {v}") + return v + + @field_validator("sparsity_n") + @classmethod + def validate_sparsity_n(cls, v): + """Validate sparsity_n is non-negative.""" + if v < 0: + raise ValueError(f"sparsity_n must be >= 0, got {v}") + return v + + @field_validator("num_sink_tokens") + @classmethod + def validate_num_sink_tokens(cls, v): + """Validate num_sink_tokens is non-negative.""" + if v < 0: + raise ValueError(f"num_sink_tokens must be >= 0, got {v}") + return v + + @field_validator("dense_window_size") + @classmethod + def validate_dense_window_size(cls, v): + """Validate dense_window_size is non-negative.""" + if v < 0: + raise ValueError(f"dense_window_size must be >= 0, got {v}") + return v + @field_validator("br", "bc") @classmethod def validate_block_size(cls, v): @@ -160,6 +225,18 @@ def validate_thresholds(cls, v): ) return v + @model_validator(mode="after") + def validate_sparsity_n_vs_m(self): + """Validate sparsity_n is within the supported range for the given sparsity_m.""" + if self.sparsity_n > 0: + max_n = 3 if self.sparsity_m == 4 else self.sparsity_m - 1 + if self.sparsity_n > max_n: + raise ValueError( + f"sparsity_n={self.sparsity_n} exceeds max for sparsity_m={self.sparsity_m}. " + f"Valid range: 1..{max_n}" + ) + return self + class CalibrationConfig(ModeloptBaseConfig): """Configuration for automatic threshold calibration using RULER dataset. 
@@ -434,9 +511,27 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +# Default N:M sparse softmax configuration +SPARSE_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 64, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "SPARSE_SOFTMAX_DEFAULT", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 8a109fda7..1bd9a547d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_skip_softmax +from . import flash_skip_softmax, triton_sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 3f3e78db6..c6a8638b7 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -38,13 +38,16 @@ def __init__(self): # Target sparsity ratio per phase: {"prefill": 0.5, "decode": 0.5} self.target_sparse_ratio: dict[str, float] | None = None - @abstractmethod def calculate_sparsity( self, attention_scores: torch.Tensor, ) -> tuple[torch.Tensor, dict]: """Calculate sparsity mask and statistics without applying. + Default: no-op (keep all). Override for methods that compute masks + outside the kernel (e.g. pytorch-backend softmax patching). + Kernel-fused methods (Triton backend) can use this default. 
+ Args: attention_scores: Pre-softmax attention scores [batch, heads, seq_q, seq_k] @@ -53,8 +56,8 @@ def calculate_sparsity( - sparse_mask: Boolean tensor indicating which elements to keep - stats_dict: Dictionary with sparsity statistics """ + return torch.ones_like(attention_scores, dtype=torch.bool), {} - @abstractmethod def apply_sparsity( self, attention_scores: torch.Tensor, @@ -62,6 +65,10 @@ def apply_sparsity( ) -> torch.Tensor: """Apply sparsity mask to attention scores. + Default: raises NotImplementedError. Override for methods that apply + masks outside the kernel. Kernel-fused methods (Triton backend) + don't need this — sparsity is applied inside the kernel. + Args: attention_scores: Pre-softmax attention scores [batch, heads, seq_q, seq_k] sparse_mask: Optional pre-computed mask. If None, calculates internally. @@ -69,6 +76,10 @@ def apply_sparsity( Returns: Masked attention scores with sparse elements set to -inf """ + raise NotImplementedError( + f"{type(self).__name__} does not implement apply_sparsity. " + "Sparsity may be fused into the kernel (Triton backend)." + ) def get_sparse_context(self, module: torch.nn.Module): """Return a context manager that activates this method's sparsity during forward. diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py new file mode 100644 index 000000000..c0639ed0b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""N:M sparse softmax method for attention scores via Triton kernel.""" + +from contextlib import contextmanager + +from .registry import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("triton_sparse_softmax") +class TritonSparseSoftmaxMethod(SparseAttentionMethod): + """N:M sparse softmax applied to attention scores via Triton kernel. + + Sparsity is applied inside the fused Triton flash attention kernel, + not as a separate pre/post-processing step. For every M consecutive + K positions, the top-N attention scores are kept; the other M-N are + set to -inf before softmax. + + Config params: + sparsity_n: Keep top-N of every M attention scores (0 to disable). + sparsity_m: Group size (4 or 8). + num_sink_tokens: KV positions before this index kept dense (attention sinks). + dense_window_size: Tokens near diagonal kept dense (absolute token count). + """ + + def __init__(self, method_config=None): + """Initialize with N:M sparsity parameters from config.""" + super().__init__() + method_config = method_config or {} + self.sparsity_n = method_config.get("sparsity_n", 2) + self.sparsity_m = method_config.get("sparsity_m", 4) + self.num_sink_tokens = method_config.get("num_sink_tokens", 0) + self.dense_window_size = method_config.get("dense_window_size", 64) + + @property + def name(self) -> str: + """Method name identifier.""" + return "triton_sparse_softmax" + + # calculate_sparsity and apply_sparsity use base class defaults + # (no-op mask and NotImplementedError) — sparsity is fused into the Triton kernel. 
+ + def get_sparse_context(self, module): + """Return context manager that activates N:M sparse softmax during forward.""" + + @contextmanager + def _sparse_nm_context(): + module._apply_sparse_nm = True + try: + yield + finally: + module._apply_sparse_nm = False + + return _sparse_nm_context() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/conftest.py b/tests/gpu/torch/sparsity/attention_sparsity/conftest.py new file mode 100644 index 000000000..fa4f61771 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/conftest.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared fixtures and helpers for Triton flash attention tests.""" + +import pytest +import torch +import torch.nn.functional as F + + +def make_qkv(total, num_heads, num_kv_heads, head_dim, device="cuda", dtype=torch.float16): + """Create packed Q, K, V tensors.""" + q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device=device, dtype=dtype) + return q, k, v + + +def make_varlen_meta(seq_lens, device="cuda"): + """Create b_start_loc and b_seq_len from a list of sequence lengths.""" + b_seq_len = torch.tensor(seq_lens, device=device, dtype=torch.int32) + b_start_loc = torch.zeros(len(seq_lens), device=device, dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0) + return b_start_loc, b_seq_len + + +def sdpa_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True): + """SDPA reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" + batch = b_seq_len.shape[0] + num_q, num_kv = q.shape[1], k.shape[1] + parts = [] + for b in range(batch): + s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item()) + qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) + if num_q != num_kv: + r = num_q // num_kv + kb = kb.repeat_interleave(r, dim=1) + vb = vb.repeat_interleave(r, dim=1) + ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=is_causal) + parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) + return torch.cat(parts, dim=0) + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" + from _test_utils.torch.transformers_models import create_tiny_llama_dir + + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, + hidden_size=64, 
+ num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + max_position_embeddings=64, + ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py index c86a4131e..5ee2696da 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""GPU tests for Triton flash attention kernel.""" +"""GPU tests for Triton flash attention kernel (dense path).""" import pytest import torch import torch.nn.functional as F +from conftest import make_qkv, make_varlen_meta, sdpa_reference pytestmark = [ pytest.mark.filterwarnings("ignore::UserWarning"), @@ -34,45 +35,14 @@ register_triton_attention() -def _sdpa_reference(q, k, v, b_start_loc, b_seq_len): - """SDPA causal reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" - batch = b_seq_len.shape[0] - num_q, num_kv = q.shape[1], k.shape[1] - parts = [] - for b in range(batch): - s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item()) - qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3) - if num_q != num_kv: - r = num_q // num_kv - kb = kb.repeat_interleave(r, dim=1) - vb = vb.repeat_interleave(r, dim=1) - ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=True) - parts.append(ob.permute(0, 2, 1, 3).squeeze(0)) - return torch.cat(parts, dim=0) - - -@pytest.fixture(scope="module") -def tiny_llama_dir(tmp_path_factory): - """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16.""" - from _test_utils.torch.transformers_models import create_tiny_llama_dir - - return create_tiny_llama_dir( - tmp_path_factory.mktemp("tiny_llama"), - with_tokenizer=True, - num_hidden_layers=2, - hidden_size=64, 
- num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=64, - max_position_embeddings=64, - ) +# --------------------------------------------------------------------------- +# Forward correctness +# --------------------------------------------------------------------------- @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestTritonFaVsSdpa: - """Triton flash attention matches PyTorch SDPA for prefill and decode.""" +class TestForward: + """Forward pass correctness for dense attention.""" @pytest.mark.parametrize( ("dtype", "num_heads", "num_kv_heads", "head_dim"), @@ -84,42 +54,27 @@ class TestTritonFaVsSdpa: ids=["fp32_mha", "fp16_gqa", "bf16_gqa_hdim128"], ) def test_prefill_matches_sdpa(self, dtype, num_heads, num_kv_heads, head_dim): - """Prefill matches SDPA.""" + """Dense prefill matches SDPA.""" seq_lens = [8, 12] total = sum(seq_lens) scale = 1.0 / (head_dim**0.5) torch.manual_seed(123) - q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) - k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) - v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) - locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) - lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) - - o = attention( - q, - k, - v, - b_start_loc=locs, - b_seq_len=lens, - max_input_len=max(seq_lens), - is_causal=True, - softmax_scale=scale, - ) - torch.testing.assert_close(o, _sdpa_reference(q, k, v, locs, lens), rtol=1e-3, atol=1e-3) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=dtype) + locs, lens = make_varlen_meta(seq_lens) + + o = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) + torch.testing.assert_close(o, sdpa_reference(q, k, v, locs, lens), rtol=1e-3, atol=1e-3) def test_decode_matches_sdpa(self): - """Decode matches SDPA.""" + """Dense decode matches SDPA.""" batch = 2 - seq_lens_k = [5, 9] # KV lengths 
(context + current token) + seq_lens_k = [5, 9] num_heads, num_kv_heads, head_dim = 4, 2, 32 scale = 1.0 / (head_dim**0.5) torch.manual_seed(103) - # Q: one token per batch element -> flat [batch, num_heads, head_dim] q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float32) - - # K/V: variable-length, packed into flat tensors total_kv = sum(seq_lens_k) k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float32) @@ -136,9 +91,9 @@ def test_decode_matches_sdpa(self): q_flat, k_flat, v_flat, - b_start_loc=b_start_loc_q, - b_seq_len=b_seq_len_q, - max_input_len=1, + b_start_loc_q, + b_seq_len_q, + 1, is_causal=False, softmax_scale=scale, b_start_loc_k=b_start_loc_k, @@ -149,7 +104,7 @@ def test_decode_matches_sdpa(self): for i in range(batch): sl = seq_lens_k[i] s = cumsum[i] - qb = q_flat[i : i + 1].unsqueeze(2) # [1, heads, 1, dim] + qb = q_flat[i : i + 1].unsqueeze(2) kb = k_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) vb = v_flat[s : s + sl].unsqueeze(0).permute(0, 2, 1, 3) kb = kb.repeat_interleave(num_heads // num_kv_heads, dim=1) @@ -157,10 +112,152 @@ def test_decode_matches_sdpa(self): ref = F.scaled_dot_product_attention(qb, kb, vb, is_causal=False).squeeze(2) torch.testing.assert_close(out[i : i + 1], ref, rtol=1e-3, atol=1e-3) + def test_sparse_disabled_matches_dense(self): + """sparsity_n=0 produces bit-identical output to default (dense).""" + seq_lens = [128, 128] + total = sum(seq_lens) + scale = 1.0 / (64**0.5) + + torch.manual_seed(99) + q, k, v = make_qkv(total, 4, 2, 64) + locs, lens = make_varlen_meta(seq_lens) + + out_dense = attention(q, k, v, locs, lens, 128, softmax_scale=scale) + out_n0 = attention(q, k, v, locs, lens, 128, softmax_scale=scale, sparsity_n=0) + assert torch.equal(out_dense, out_n0) + + +# --------------------------------------------------------------------------- +# Backward 
correctness (dense) +# --------------------------------------------------------------------------- + @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestSparseAttentionIntegration: - """HF model + mtsa.sparsify integration.""" +class TestBackward: + """Backward pass gradient correctness for dense attention.""" + + def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): + """Run SDPA forward+backward, return gradients.""" + q_ref = q.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + k_ref = k.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + v_ref = v.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) + num_q, num_kv = q_ref.shape[1], k_ref.shape[1] + if num_q != num_kv: + r = num_q // num_kv + k_exp = k_ref.repeat_interleave(r, dim=1) + v_exp = v_ref.repeat_interleave(r, dim=1) + else: + k_exp, v_exp = k_ref, v_ref + o_ref = F.scaled_dot_product_attention( + q_ref, k_exp, v_exp, is_causal=is_causal, scale=scale + ) + o_ref.sum().backward() + dq = q_ref.grad.permute(0, 2, 1, 3).squeeze(0) + dk = k_ref.grad.permute(0, 2, 1, 3).squeeze(0) + dv = v_ref.grad.permute(0, 2, 1, 3).squeeze(0) + return dq.detach(), dk.detach(), dv.detach() + + def test_dense_causal_matches_sdpa(self): + """dQ, dK, dV match SDPA for causal self-attention.""" + seq_len, num_heads, num_kv_heads, head_dim = 16, 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(42) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + 
torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + def test_dense_gqa_matches_sdpa(self): + """Dense backward with GQA (4 q-heads, 2 kv-heads), seq_len=256.""" + seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(43) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + def test_dense_multi_batch_variable_length(self): + """Multi-batch variable-length backward matches per-sample SDPA.""" + seq_lens = [8, 12] + total = sum(seq_lens) + num_heads, num_kv_heads, head_dim = 2, 2, 32 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(45) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta(seq_lens) + + attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale).sum().backward() + + dq_ref = torch.zeros_like(q) + dk_ref = torch.zeros_like(k) + dv_ref = torch.zeros_like(v) + for b in range(len(seq_lens)): + s, n = int(locs[b].item()), seq_lens[b] + dq_b, dk_b, dv_b = self._sdpa_backward_ref( + q.detach()[s : s + n], + k.detach()[s : s + n], + v.detach()[s : s + n], + scale, + ) + dq_ref[s : s + n] = dq_b + dk_ref[s : s + n] = dk_b + dv_ref[s : s + n] = dv_b + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + 
torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + def test_dense_longer_sequences(self): + """Dense backward with seq_len=512, GQA, exercises multi-tile loops.""" + seq_len, num_heads, num_kv_heads, head_dim = 512, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(49) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention(q, k, v, locs, lens, seq_len, softmax_scale=scale).sum().backward() + dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref(q.detach(), k.detach(), v.detach(), scale) + + torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) + + +# --------------------------------------------------------------------------- +# HuggingFace integration +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestHFIntegration: + """HF model integration with Triton attention backend.""" def test_triton_matches_eager(self, tiny_llama_dir): """Triton attention produces same logits and generated tokens as eager.""" @@ -172,7 +269,6 @@ def test_triton_matches_eager(self, tiny_llama_dir): tok.pad_token_id = tok.eos_token_id ids = tok("The capital of France is", return_tensors="pt").input_ids.to("cuda") - # Eager baseline model_eager = AutoModelForCausalLM.from_pretrained( tiny_llama_dir, attn_implementation="eager", @@ -190,7 +286,6 @@ def test_triton_matches_eager(self, tiny_llama_dir): ) del model_eager - # Triton model_triton = AutoModelForCausalLM.from_pretrained( tiny_llama_dir, attn_implementation="modelopt_triton", @@ -207,15 +302,13 @@ def test_triton_matches_eager(self, tiny_llama_dir): pad_token_id=tok.pad_token_id, ) - # 
Logits should be close (bf16 tolerance) torch.testing.assert_close(logits_triton, logits_eager, rtol=2e-2, atol=2e-2) - # Generated tokens must be identical (greedy decoding is deterministic) assert torch.equal(out_triton, out_eager), ( f"Generated tokens differ:\n eager: {out_eager}\n triton: {out_triton}" ) def test_triton_padded_batch(self, tiny_llama_dir): - """Padded batch (2D attention mask) produces valid logits for each sequence.""" + """Padded batch produces valid logits.""" pytest.importorskip("transformers") from transformers import AutoModelForCausalLM, AutoTokenizer @@ -240,200 +333,61 @@ def test_triton_padded_batch(self, tiny_llama_dir): logits = model(**inputs).logits assert not torch.isnan(logits).any() and not torch.isinf(logits).any() + def test_sparse_nm_via_sparsify(self, tiny_llama_dir): + """mtsa.sparsify() with N:M sparse softmax produces finite logits that differ from dense.""" + pytest.importorskip("transformers") + from transformers import AutoModelForCausalLM, AutoTokenizer -@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") -class TestBackward: - """Backward pass gradient correctness tests.""" - - def _sdpa_backward_ref(self, q, k, v, scale, is_causal=True): - """Run SDPA forward+backward, return output and gradients.""" - q_ref = q.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) - k_ref = k.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) - v_ref = v.clone().unsqueeze(0).permute(0, 2, 1, 3).requires_grad_(True) - num_q, num_kv = q_ref.shape[1], k_ref.shape[1] - if num_q != num_kv: - r = num_q // num_kv - k_exp = k_ref.repeat_interleave(r, dim=1) - v_exp = v_ref.repeat_interleave(r, dim=1) - else: - k_exp, v_exp = k_ref, v_ref - o_ref = F.scaled_dot_product_attention( - q_ref, k_exp, v_exp, is_causal=is_causal, scale=scale - ) - o_ref.sum().backward() - dq = q_ref.grad.permute(0, 2, 1, 3).squeeze(0) - dk = k_ref.grad.permute(0, 2, 1, 3).squeeze(0) - dv = v_ref.grad.permute(0, 2, 1, 
3).squeeze(0) - return o_ref.permute(0, 2, 1, 3).squeeze(0).detach(), dq.detach(), dk.detach(), dv.detach() - - def test_backward_causal_matches_sdpa(self): - """dQ, dK, dV match SDPA backward for causal self-attention.""" - from modelopt.torch.kernels import attention - - seq_len = 16 - num_heads, num_kv_heads, head_dim = 2, 2, 32 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(42) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True - ) - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - - def test_backward_gqa(self): - """Backward with GQA (4 q-heads, 2 kv-heads), multi-tile (seq_len=256).""" - from modelopt.torch.kernels import attention - - seq_len = 256 - num_heads, num_kv_heads, head_dim = 4, 2, 32 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(43) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], 
device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True - ) - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - - def test_backward_multi_batch_variable_length(self): - """Multi-batch variable-length causal backward matches per-sample SDPA.""" - from modelopt.torch.kernels import attention - - seq_lens = [8, 12] - total = sum(seq_lens) - num_heads, num_kv_heads, head_dim = 2, 2, 32 - scale = 1.0 / (head_dim**0.5) - - torch.manual_seed(45) - q = torch.randn( - total, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32) - lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32) - - o = attention( - q, - k, - v, - b_start_loc=locs, - b_seq_len=lens, - max_input_len=max(seq_lens), - is_causal=True, - softmax_scale=scale, - ) - o.sum().backward() - - # Per-sample SDPA reference - dq_ref = torch.zeros_like(q) - dk_ref = torch.zeros_like(k) - dv_ref = torch.zeros_like(v) - for b in range(len(seq_lens)): - s, n = int(locs[b].item()), seq_lens[b] - _, dq_b, dk_b, dv_b = self._sdpa_backward_ref( - q.detach()[s : s + n], - k.detach()[s : s + n], - v.detach()[s : s + n], - scale, - is_causal=True, - ) - dq_ref[s : s + n] = dq_b - dk_ref[s : s + n] = dk_b - dv_ref[s : s + n] = dv_b - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - 
torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) - - def test_backward_longer_sequences(self): - """Backward with seq_len=512, GQA, exercises multi-tile loops.""" - from modelopt.torch.kernels import attention + import modelopt.torch.sparsity.attention_sparsity as mtsa - seq_len = 512 - num_heads, num_kv_heads, head_dim = 4, 2, 64 - scale = 1.0 / (head_dim**0.5) + tok = AutoTokenizer.from_pretrained(tiny_llama_dir) + if tok.pad_token_id is None: + tok.pad_token_id = tok.eos_token_id + # Use a long input (fill max_position_embeddings=64) so sparsity has tiles to prune + ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda") - torch.manual_seed(49) - q = torch.randn( - seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32, requires_grad=True + # Dense baseline (triton backend, no sparsity) + model_dense = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="modelopt_triton", + torch_dtype=torch.bfloat16, + device_map="cuda", ) - - o = attention( - q, - k, - v, - b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32), - b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32), - max_input_len=seq_len, - is_causal=True, - softmax_scale=scale, + model_dense.eval() + with torch.no_grad(): + logits_dense = model_dense(input_ids=ids).logits + del model_dense + + # Sparse via mtsa.sparsify() with dense_window_size=0 to force sparsity on all tiles + sparse_cfg = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 0, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, + } + 
model_sparse = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + torch_dtype=torch.bfloat16, + device_map="cuda", ) - o.sum().backward() - - _, dq_ref, dk_ref, dv_ref = self._sdpa_backward_ref( - q.detach(), k.detach(), v.detach(), scale, is_causal=True + mtsa.sparsify(model_sparse, sparse_cfg) + assert model_sparse.config._attn_implementation == "modelopt_triton" + model_sparse.eval() + with torch.no_grad(): + logits_sparse = model_sparse(input_ids=ids).logits + + # Sparse output should be finite + assert not torch.isnan(logits_sparse).any(), "NaN in sparse logits" + assert not torch.isinf(logits_sparse).any(), "Inf in sparse logits" + # Sparse output should differ from dense (sparsity changes attention) + assert not torch.allclose(logits_sparse, logits_dense, atol=1e-2), ( + "Sparse logits identical to dense — sparsity may not be applied" ) - - torch.testing.assert_close(q.grad, dq_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(k.grad, dk_ref, rtol=5e-3, atol=5e-3) - torch.testing.assert_close(v.grad, dv_ref, rtol=5e-3, atol=5e-3) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py new file mode 100644 index 000000000..7fd961a41 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: N803 — Triton JIT wrapper uses uppercase for constexpr and tensor args + +"""GPU tests for N:M sparse softmax on the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_qkv, make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + import triton + import triton.language as tl + + from modelopt.torch.kernels import attention + from modelopt.torch.kernels.triton_fa import _apply_sparse_nm_to_qk_tile + + @triton.jit + def _test_apply_sparse_nm( + In, + Out, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, + SPARSITY_M: tl.constexpr, + ): + """Test wrapper: apply N:M sparsity to a tile and store result.""" + offs = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + qk = tl.load(In + offs) + tl.store( + Out + offs, + _apply_sparse_nm_to_qk_tile(qk, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M), + ) + + +# --------------------------------------------------------------------------- +# N:M sparsity behavior (prefill only) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseNM: + """N:M sparse softmax behavior on attention scores.""" + + def _make_inputs(self, batch=2, seq_len=256, num_heads=4, num_kv_heads=2, head_dim=64): + total = batch * seq_len + torch.manual_seed(99) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + return q, k, v, locs, lens + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 
8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_output_shape(self, n, m): + """Output shape matches Q shape for all N:M patterns.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m + ) + assert out.shape == q.shape + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_no_nan(self, n, m): + """All N:M patterns produce finite output.""" + q, k, v, locs, lens = self._make_inputs() + out = attention( + q, k, v, locs, lens, 256, softmax_scale=1.0 / 8.0, sparsity_n=n, sparsity_m=m + ) + assert not torch.isnan(out).any(), f"NaN in output for {n}:{m}" + assert not torch.isinf(out).any(), f"Inf in output for {n}:{m}" + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (1, 8), (4, 8)], + ids=["1:4", "2:4", "1:8", "4:8"], + ) + def test_sparse_differs_from_dense(self, n, m): + """Sparse output should differ from dense for long sequences.""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + out_sparse = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + assert not torch.allclose(out_sparse, out_dense, atol=1e-3) + + @pytest.mark.parametrize( + ("n_values", "m"), + [([1, 2, 3], 4), ([1, 2, 4], 8)], + ids=["m4", "m8"], + ) + def test_more_sparsity_more_error(self, n_values, m): + """Keeping more elements should deviate less from dense (monotonic decreasing error).""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + errors = [] + for n in n_values: + out = attention( + q, k, v, locs, lens, 512, softmax_scale=scale, sparsity_n=n, sparsity_m=m + ) + errors.append((out - 
out_dense).abs().mean().item()) + for i in range(len(errors) - 1): + assert errors[i] > errors[i + 1], ( + f"Errors not monotonically decreasing for M={m}: " + + ", ".join(f"{n}:{m}={e:.6f}" for n, e in zip(n_values, errors)) + ) + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_dense_window_preserves_local(self, n, m): + """Large dense_window_size makes sparse output closer to dense.""" + q, k, v, locs, lens = self._make_inputs(seq_len=256) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 256, softmax_scale=scale) + out_small = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_size=64, + ) + out_large = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + dense_window_size=100000, + ) + err_small = (out_small - out_dense).abs().mean().item() + err_large = (out_large - out_dense).abs().mean().item() + assert err_large < err_small + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sink_tokens_preserve_early_kv(self, n, m): + """num_sink_tokens keeps early KV positions dense, reducing error vs fully sparse.""" + q, k, v, locs, lens = self._make_inputs(seq_len=512) + scale = 1.0 / (64**0.5) + out_dense = attention(q, k, v, locs, lens, 512, softmax_scale=scale) + out_no_sink = attention( + q, + k, + v, + locs, + lens, + 512, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + num_sink_tokens=0, + ) + out_with_sink = attention( + q, + k, + v, + locs, + lens, + 512, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + num_sink_tokens=128, + ) + err_no_sink = (out_no_sink - out_dense).abs().mean().item() + err_with_sink = (out_with_sink - out_dense).abs().mean().item() + assert err_with_sink < err_no_sink, ( + f"Sink tokens should reduce error: no_sink={err_no_sink:.6f}, with_sink={err_with_sink:.6f}" + ) + + # NOTE: N:M 
sparse attention is for prefill only, not decode. + + +# --------------------------------------------------------------------------- +# Sparsity tile structure +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseTileStructure: + """Direct unit tests for _apply_sparse_nm_to_qk_tile via wrapper kernel.""" + + @pytest.mark.parametrize( + ("n", "m"), + [(1, 4), (2, 4), (3, 4), (1, 8), (2, 8), (4, 8)], + ids=["1:4", "2:4", "3:4", "1:8", "2:8", "4:8"], + ) + def test_sparsity_structure(self, n, m): + """Verify N:M structure: exactly N kept per group of M.""" + bm, bn = 32, 64 + torch.manual_seed(88) + tile = torch.randn(bm, bn, device="cuda", dtype=torch.float32) + out = torch.empty_like(tile) + _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=bm, BLOCK_N=bn, SPARSITY_N=n, SPARSITY_M=m) + + kept = (out.reshape(bm, bn // m, m) != float("-inf")).sum(dim=-1) + assert (kept == n).all(), ( + f"Expected {n} kept per group of {m}, got min={kept.min()}, max={kept.max()}" + ) + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sparsity_structure_ties(self, n, m): + """M=4 keeps exactly N on ties; M=8 (tl.sort) may keep >= N on ties.""" + bm, bn = 32, 64 + tile = torch.ones(bm, bn, device="cuda", dtype=torch.float32) + out = torch.empty_like(tile) + _test_apply_sparse_nm[(1,)](tile, out, BLOCK_M=bm, BLOCK_N=bn, SPARSITY_N=n, SPARSITY_M=m) + + kept = (out.reshape(bm, bn // m, m) != float("-inf")).sum(dim=-1) + if m == 4: + assert (kept == n).all(), ( + f"M=4 tie: expected {n}, got min={kept.min()}, max={kept.max()}" + ) + else: + assert (kept >= n).all(), f"M=8 tie: expected >= {n}, got min={kept.min()}" + assert (kept <= m).all(), f"M=8 tie: expected <= {m}, got max={kept.max()}" + + +# --------------------------------------------------------------------------- +# Sparse backward sanity +# 
--------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSparseBackward: + """Backward pass sanity checks with N:M sparsity enabled.""" + + @pytest.mark.parametrize( + ("n", "m"), + [(2, 4), (4, 8)], + ids=["2:4", "4:8"], + ) + def test_sparse_gradients_finite(self, n, m): + """Backward with N:M sparsity produces finite, non-zero gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 128, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(55) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + locs, lens = make_varlen_meta([seq_len]) + + attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + sparsity_n=n, + sparsity_m=m, + ).sum().backward() + + for name, grad in [("dQ", q.grad), ("dK", k.grad), ("dV", v.grad)]: + assert grad is not None, f"{name} is None for {n}:{m}" + assert not torch.isnan(grad).any(), f"NaN in {name} for {n}:{m}" + assert not torch.isinf(grad).any(), f"Inf in {name} for {n}:{m}" + assert grad.abs().sum() > 0, f"{name} is all zeros for {n}:{m}" + + def test_sparse_gradients_differ_from_dense(self): + """Gradients with 2:4 sparsity should differ from dense gradients.""" + seq_len, num_heads, num_kv_heads, head_dim = 256, 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(66) + q, k, v = make_qkv(seq_len, num_heads, num_kv_heads, head_dim, dtype=torch.float32) + locs, lens = make_varlen_meta([seq_len]) + + q_d = q.clone().requires_grad_(True) + k_d = k.clone().requires_grad_(True) + v_d = v.clone().requires_grad_(True) + attention(q_d, k_d, v_d, locs, lens, seq_len, softmax_scale=scale).sum().backward() + + q_s = q.clone().requires_grad_(True) + k_s = k.clone().requires_grad_(True) + v_s = v.clone().requires_grad_(True) + attention( + q_s, + k_s, + v_s, + locs, + lens, + 
seq_len, + softmax_scale=scale, + sparsity_n=2, + sparsity_m=4, + ).sum().backward() + + assert not torch.allclose(q_d.grad, q_s.grad, atol=1e-3), ( + "dQ same with and without sparsity" + ) + assert not torch.allclose(k_d.grad, k_s.grad, atol=1e-3), ( + "dK same with and without sparsity" + ) + assert not torch.allclose(v_d.grad, v_s.grad, atol=1e-3), ( + "dV same with and without sparsity" + )
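The N:M selection these tests exercise can be sketched in plain PyTorch. This is an illustrative reference for the semantics only, not the Triton implementation (`_apply_sparse_nm_to_qk_tile`); the helper name `sparse_nm_mask` is hypothetical, and it assumes the behavior the tests assert: for every M consecutive key positions, keep the top-N scores and set the rest to -inf before softmax.

```python
import torch


def sparse_nm_mask(scores: torch.Tensor, n: int, m: int) -> torch.Tensor:
    """Keep the top-n scores in every group of m consecutive key positions;
    set the rest to -inf so they receive zero weight under softmax."""
    *lead, keys = scores.shape
    assert keys % m == 0, "key length must be a multiple of m"
    groups = scores.reshape(*lead, keys // m, m)
    # Mark the top-n entries within each group of m.
    top_idx = groups.topk(n, dim=-1).indices
    keep = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, top_idx, True)
    masked = torch.where(keep, groups, torch.full_like(groups, float("-inf")))
    return masked.reshape(scores.shape)


# 2 query rows x 8 key positions, 2:4 sparsity
scores = torch.randn(2, 8)
masked = sparse_nm_mask(scores, n=2, m=4)
kept_per_group = (masked != float("-inf")).reshape(2, 2, 4).sum(dim=-1)
probs = torch.softmax(masked, dim=-1)  # dropped positions get exactly zero weight
```

Since `topk` always returns exactly n indices, every group keeps exactly n entries even on ties, mirroring the `test_sparsity_structure` check and the exact-N tie behavior asserted for the M=4 kernel path (the M=8 `tl.sort` path may keep more on ties).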