Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity.config import (
SKIP_SOFTMAX_CALIB,
SKIP_SOFTMAX_DEFAULT,
SPARSE_SOFTMAX_DEFAULT,
)
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

RAND_SEED = 1234
Expand All @@ -42,9 +38,7 @@

# Sparse attention configuration choices
SPARSE_ATTN_CFG_CHOICES = {
"skip_softmax": SKIP_SOFTMAX_DEFAULT,
"skip_softmax_calib": SKIP_SOFTMAX_CALIB,
"sparse_softmax": SPARSE_SOFTMAX_DEFAULT,
}


Expand Down Expand Up @@ -236,7 +230,7 @@ def main(args):
parser.add_argument(
"--sparse_attn",
type=str,
default="skip_softmax",
default="skip_softmax_calib",
choices=list(SPARSE_ATTN_CFG_CHOICES.keys()),
help="Sparse attention configuration to apply.",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from _test_utils.examples.run_command import extend_cmd_parts, run_example_command


def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", **kwargs):
def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax_calib", **kwargs):
"""Run attention sparsity example script.

Args:
Expand All @@ -40,7 +40,7 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax",
run_example_command(cmd_parts, "llm_sparsity/attention_sparsity")


@pytest.mark.parametrize("method", ["skip_softmax"])
@pytest.mark.parametrize("method", ["skip_softmax_calib"])
def test_attention_sparsity(tiny_llama_path, tmp_path, method):
"""Test sparse attention with TinyLlama (with and without calibration)."""
run_attention_sparsity_command(
Expand Down
Loading