Merged
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -7,6 +7,7 @@ NVIDIA Model Optimizer Changelog

- Support the full Transformer Engine spec for Minitron pruning (``mcore_minitron``), so a custom ModelOpt spec is no longer needed. This does not change how the pruning workflow is used, but it makes pruning slightly faster and may produce a slightly different pruned model due to different kernels and numerics.
- Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.

**Bug Fixes**
58 changes: 52 additions & 6 deletions examples/llm_sparsity/attention_sparsity/README.md
@@ -1,6 +1,11 @@
# Attention Sparsity for HuggingFace Models

In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. Two attention backends are supported:
In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Two sparsity methods are supported:

- **Skip-softmax** (`flash_skip_softmax`): Skips attention tiles whose contribution falls below a threshold, following the [BLASST](https://arxiv.org/pdf/2512.12087) algorithm.
- **N:M sparse softmax** (`triton_sparse_softmax`): For every M consecutive key positions, keeps the top-N attention scores and sets the rest to -inf before softmax.

Two attention backends are available:

- **pytorch** (default): Patches `F.softmax` to apply skip-softmax sparsity (requires `attn_implementation="eager"`)
- **triton**: Uses a fused Triton Flash Attention kernel with in-kernel sparsity (uses `attn_implementation="modelopt_triton"`)
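To make the N:M scheme concrete, here is a minimal reference sketch of N:M sparse softmax on a `(q_len, k_len)` score matrix. This is a standalone illustration, not the fused Triton kernel; in particular, the dense-window semantics (positions with `|q - k| < window` stay dense) and the helper's name are assumptions for illustration:

```python
import torch


def nm_sparse_softmax_mask(
    scores: torch.Tensor,
    n: int = 2,
    m: int = 4,
    num_sink_tokens: int = 0,
    dense_window_size: int = 0,
) -> torch.Tensor:
    """Reference N:M sparse softmax: within every group of m consecutive
    key positions, keep the top-n scores and set the rest to -inf before
    softmax. Sink tokens and a local window stay dense (assumed semantics)."""
    q_len, k_len = scores.shape
    assert 1 <= n < m and k_len % m == 0
    masked = scores.clone()
    groups = masked.view(q_len, k_len // m, m)  # view shares storage with `masked`
    # Indices of the (m - n) smallest scores in each group are masked out.
    drop = groups.topk(m - n, dim=-1, largest=False).indices
    groups.scatter_(-1, drop, float("-inf"))
    # Exemptions: attention sinks (first tokens) and a local dense window.
    masked[:, :num_sink_tokens] = scores[:, :num_sink_tokens]
    if dense_window_size > 0:
        dist = torch.arange(q_len).unsqueeze(1) - torch.arange(k_len).unsqueeze(0)
        local = dist.abs() < dense_window_size
        masked[local] = scores[local]
    return torch.softmax(masked, dim=-1)


scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
probs = nm_sparse_softmax_mask(scores, n=2, m=4)
# With 2:4, only the top-2 of each group of 4 (positions 2, 3 and 6, 7)
# receive nonzero probability mass.
```

The fused kernel applies the same masking inside the attention loop, so the dense `(q_len, k_len)` score matrix above is never materialized in practice.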
@@ -29,9 +34,9 @@ model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)

## Configuration Options

Two pre-defined configurations are available:
### Skip-Softmax

### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT)
#### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT)

Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths.

@@ -41,7 +46,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAU
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```

### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB)
#### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB)

Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use.

@@ -51,6 +56,46 @@ from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB)
```

### N:M Sparse Softmax (SPARSE_SOFTMAX_DEFAULT)

Applies N:M structured sparsity to attention scores using the Triton backend. For every M consecutive key positions, keeps only the top-N scores and sets the rest to -inf. Supports M=4 (N=1,2,3) and M=8 (N=1..7). Attention sinks and a local dense window can be configured to preserve important positions.

```python
from modelopt.torch.sparsity.attention_sparsity.config import SPARSE_SOFTMAX_DEFAULT

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
torch_dtype=torch.bfloat16,
device_map="cuda",
)

model = mtsa.sparsify(model, config=SPARSE_SOFTMAX_DEFAULT)
```

Custom N:M configuration:

```python
sparse_cfg = {
"sparse_cfg": {
"*attn*": {
"method": "triton_sparse_softmax",
"sparsity_n": 2, # Keep top-2 of every 4
"sparsity_m": 4, # Group size
"num_sink_tokens": 4, # Keep first 4 tokens dense (attention sinks)
"dense_window_size": 128, # Keep tokens within distance 128 dense
"backend": "triton",
"enable": True,
},
"default": {"enable": False},
},
}

model = mtsa.sparsify(model, config=sparse_cfg)
```

> [!NOTE]
> N:M sparse softmax requires the Triton backend (`backend="triton"`). The `attn_implementation` is automatically set to `"modelopt_triton"` by `mtsa.sparsify()`. N:M sparsity is applied during prefill only — decode tokens are not sparsified.

## Prerequisites

### Local Installation
@@ -104,8 +149,8 @@ The calibration process:
| Argument | Default | Description |
|----------|---------|-------------|
| `--pyt_ckpt_path` | Required | HuggingFace model path or name |
| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` |
| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) |
| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax`, `skip_softmax_calib`, or `sparse_softmax` |
| `--backend` | `pytorch` | Backend: `pytorch` (skip-softmax) or `triton` (N:M sparse softmax) |
| `--seq_len` | `2048` | Maximum sequence length for input prompts |
| `--export_dir` | `None` | Directory to export the sparsified model |
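Putting the flags together, a typical invocation for N:M sparsity with the Triton backend might look like the following (the model name is illustrative; `hf_sa.py` is the script in this example directory):

```shell
# Apply N:M sparse softmax via the fused Triton kernel (prefill only)
python hf_sa.py \
    --pyt_ckpt_path meta-llama/Llama-3.1-8B \
    --sparse_attn sparse_softmax \
    --backend triton \
    --seq_len 2048 \
    --export_dir ./sparse_model
```

Omitting `--backend` keeps whichever backend the chosen configuration specifies.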

@@ -166,3 +211,4 @@ model = mtsa.sparsify(model, config=custom_config)

- [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/)
- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER)
- [BLASST: Block-Level Adaptive Structured Sparse Training](https://arxiv.org/pdf/2512.12087) — skip-softmax algorithm
14 changes: 9 additions & 5 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
@@ -31,6 +31,7 @@
from modelopt.torch.sparsity.attention_sparsity.config import (
SKIP_SOFTMAX_CALIB,
SKIP_SOFTMAX_DEFAULT,
SPARSE_SOFTMAX_DEFAULT,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

@@ -43,6 +44,7 @@
SPARSE_ATTN_CFG_CHOICES = {
"skip_softmax": SKIP_SOFTMAX_DEFAULT,
"skip_softmax_calib": SKIP_SOFTMAX_CALIB,
"sparse_softmax": SPARSE_SOFTMAX_DEFAULT,
}


@@ -168,9 +170,10 @@ def main(args):

# Apply CLI overrides to sparse_cfg
sparse_cfg = sparse_config.get("sparse_cfg", {})
for layer_cfg in sparse_cfg.values():
if isinstance(layer_cfg, dict) and "method" in layer_cfg:
layer_cfg["backend"] = args.backend
if args.backend is not None:
for layer_cfg in sparse_cfg.values():
if isinstance(layer_cfg, dict) and "method" in layer_cfg:
layer_cfg["backend"] = args.backend
if args.target_sparse_ratio is not None:
calib = sparse_cfg.setdefault("calibration", {})
assert isinstance(calib, dict)
@@ -240,9 +243,10 @@ def main(args):
parser.add_argument(
"--backend",
type=str,
default="pytorch",
default=None,
choices=["pytorch", "triton"],
help="Backend for sparse attention (default: pytorch). 'triton' uses the fused Triton kernel.",
help="Backend for sparse attention. Overrides the config default if set. "
"'triton' uses the fused Triton kernel.",
)

# Sequence length arguments
11 changes: 11 additions & 0 deletions modelopt/torch/kernels/hf_triton_attention.py
@@ -105,6 +105,17 @@ def triton_attention_forward(
kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
kw["max_input_len_k"] = seq_k

# N:M sparse softmax — prefill only (decode should not sparsify KV)
if not is_decode and getattr(module, "_apply_sparse_nm", False):
# _sparse_method_instance is set by SparseAttentionModule._init_sparse_method()
# in modelopt/torch/sparsity/attention_sparsity/sparse_attention.py
method = getattr(module, "_sparse_method_instance", None)
> **Reviewer:** Where `_sparse_method_instance` gets set, if it's outside this PR, please add a comment.
>
> **Author:** Added.
if method is not None:
kw["sparsity_n"] = getattr(method, "sparsity_n", 2)
kw["sparsity_m"] = getattr(method, "sparsity_m", 4)
kw["num_sink_tokens"] = getattr(method, "num_sink_tokens", 0)
kw["dense_window_size"] = getattr(method, "dense_window_size", 64)

o = attention(q, k, v, **kw)

attn_output = o.view(batch, seq_len, num_heads, head_dim)