Replace MLA Triton/TileLang kernels with flash-attn-4#3

Draft
haok1402 wants to merge 2 commits into mlc-ai:main from haok1402:0329-flash-attn-4
Conversation

@haok1402
Contributor

@haok1402 haok1402 commented Mar 29, 2026

Kernel Performance on NVIDIA H100

MLA: flash-attn-4 vs TileLang (h=16, dq=192, dv=128)

Forward

| Seq Len | FA4 (ms) | FA4 TFLOPS | TileLang (ms) | TileLang TFLOPS | FA4 speedup |
|--------:|---------:|-----------:|--------------:|----------------:|------------:|
| 2048    | 0.051    | 422.47     | 0.073         | 293.83          | 1.4x        |
| 4096    | 0.150    | 571.90     | 0.197         | 436.74          | 1.3x        |
| 8192    | 0.563    | 610.64     | 0.656         | 523.51          | 1.2x        |
| 16384   | 2.169    | 633.59     | 2.525         | 544.25          | 1.2x        |
| 32768   | 8.824    | 623.04     | 9.420         | 583.62          | 1.1x        |

Backward

| Seq Len | FA4 (ms) | FA4 TFLOPS | TileLang (ms) | TileLang TFLOPS | FA4 speedup |
|--------:|---------:|-----------:|--------------:|----------------:|------------:|
| 2048    | 0.199    | 270.04     | 0.321         | 167.34          | 1.6x        |
| 4096    | 0.577    | 372.05     | 0.904         | 237.45          | 1.6x        |
| 8192    | 1.974    | 435.14     | 2.883         | 297.99          | 1.5x        |
| 16384   | 7.532    | 456.21     | 9.905         | 346.90          | 1.3x        |
| 32768   | 28.710   | 478.71     | 39.717        | 346.04          | 1.4x        |
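
For reference, the TFLOPS figures above are consistent with the standard causal-attention FLOP count for these head dims: 2·s²·h·(dq+dv), halved for causal masking, with backward conventionally counted as 2.5× forward. A small sketch of that accounting (my own reconstruction, not the benchmark script):

```python
# Rough FLOP accounting for causal MLA attention, matching the tables above.
# h = heads, dq = QK head dim, dv = V head dim.
def mla_tflops(seq_len, time_ms, h=16, dq=192, dv=128, backward=False):
    # QK^T costs 2*s^2*h*dq and P@V costs 2*s^2*h*dv; causal masking
    # halves the work since only the lower triangle is computed.
    flops = 2 * seq_len**2 * h * (dq + dv) / 2
    if backward:
        flops *= 2.5  # dQ, dK, dV passes: conventional 2.5x forward cost
    return flops / (time_ms * 1e-3) / 1e12

# e.g. the H100 forward row at s=2048 (FA4, 0.051 ms) gives ~421 TFLOPS,
# close to the 422.47 reported; the gap is rounding in the measured time.
print(f"{mla_tflops(2048, 0.051):.2f}")
```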

Kernel Performance on NVIDIA B200

MLA: flash-attn-4 vs Triton MLA (h=16, dq=192, dv=128)

Forward

| Seq Len | FA4 (ms) | FA4 TFLOPS | Triton (ms) | Triton TFLOPS | FA4 speedup |
|--------:|---------:|-----------:|------------:|--------------:|------------:|
| 2048    | 0.040    | 536.81     | 0.099       | 216.61        | 2.5x        |
| 4096    | 0.075    | 1147.63    | 0.249       | 344.96        | 3.3x        |
| 8192    | 0.261    | 1318.74    | 0.748       | 459.41        | 2.9x        |
| 16384   | 1.012    | 1358.72    | 2.535       | 542.06        | 2.5x        |
| 32768   | 3.915    | 1404.26    | 9.296       | 591.41        | 2.4x        |

Backward

| Seq Len | FA4 (ms) | FA4 TFLOPS | Triton (ms) | Triton TFLOPS | FA4 speedup |
|--------:|---------:|-----------:|------------:|--------------:|------------:|
| 2048    | 0.130    | 414.01     | 0.296       | 181.53        | 2.3x        |
| 4096    | 0.330    | 650.25     | 0.777       | 276.36        | 2.4x        |
| 8192    | 1.031    | 833.26     | 2.330       | 368.70        | 2.3x        |
| 16384   | 3.628    | 947.15     | 8.060       | 426.28        | 2.2x        |
| 32768   | 13.813   | 995.03     | 29.895      | 459.74        | 2.2x        |

End-to-End Performance on NVIDIA H100

Model: deepseek-v2-lite; Device: 1x8-h100; Mesh: 2pp-2dp-1cp-2ep; Sequence Length: 2048; Dtype: bf16.

Before the change, using MLA with TileLang,

2026-03-31 22:17:25 | INFO | step 00000016/00000025 | step-time 24.666 sec | cross-entropy-loss 2.4565 | load-balance-loss 0.003234 | learning-rate 1.000000e-06 | gradient-norm 2.1498 | tokens-per-second 85,021 | peak-gpu-memory 57.22 GB
2026-03-31 22:17:49 | INFO | step 00000017/00000025 | step-time 24.675 sec | cross-entropy-loss 2.4274 | load-balance-loss 0.003240 | learning-rate 1.000000e-06 | gradient-norm 1.8743 | tokens-per-second 84,993 | peak-gpu-memory 57.19 GB
2026-03-31 22:18:15 | INFO | step 00000018/00000025 | step-time 24.741 sec | cross-entropy-loss 2.4353 | load-balance-loss 0.003229 | learning-rate 1.000000e-06 | gradient-norm 1.4246 | tokens-per-second 84,763 | peak-gpu-memory 57.20 GB
2026-03-31 22:18:39 | INFO | step 00000019/00000025 | step-time 24.536 sec | cross-entropy-loss 2.4145 | load-balance-loss 0.003224 | learning-rate 1.000000e-06 | gradient-norm 1.3992 | tokens-per-second 85,472 | peak-gpu-memory 57.19 GB
2026-03-31 22:19:04 | INFO | step 00000020/00000025 | step-time 24.554 sec | cross-entropy-loss 2.4181 | load-balance-loss 0.003223 | learning-rate 1.000000e-06 | gradient-norm 1.3638 | tokens-per-second 85,410 | peak-gpu-memory 57.19 GB
2026-03-31 22:19:29 | INFO | step 00000021/00000025 | step-time 24.634 sec | cross-entropy-loss 2.4449 | load-balance-loss 0.003224 | learning-rate 1.000000e-06 | gradient-norm 1.3544 | tokens-per-second 85,134 | peak-gpu-memory 57.19 GB
2026-03-31 22:19:54 | INFO | step 00000022/00000025 | step-time 24.609 sec | cross-entropy-loss 2.4287 | load-balance-loss 0.003230 | learning-rate 1.000000e-06 | gradient-norm 1.2901 | tokens-per-second 85,220 | peak-gpu-memory 57.17 GB
2026-03-31 22:20:19 | INFO | step 00000023/00000025 | step-time 24.639 sec | cross-entropy-loss 2.3759 | load-balance-loss 0.003223 | learning-rate 1.000000e-06 | gradient-norm 1.2806 | tokens-per-second 85,117 | peak-gpu-memory 57.19 GB
2026-03-31 22:20:44 | INFO | step 00000024/00000025 | step-time 24.556 sec | cross-entropy-loss 2.4278 | load-balance-loss 0.003230 | learning-rate 1.000000e-06 | gradient-norm 1.3140 | tokens-per-second 85,402 | peak-gpu-memory 57.16 GB
2026-03-31 22:21:09 | INFO | step 00000025/00000025 | step-time 24.641 sec | cross-entropy-loss 2.4256 | load-balance-loss 0.003226 | learning-rate 1.000000e-06 | gradient-norm 1.2105 | tokens-per-second 85,108 | peak-gpu-memory 57.16 GB

After the change, using flash-attn-4,

2026-03-31 22:03:19 | INFO | step 00000016/00000025 | step-time 25.491 sec | cross-entropy-loss 2.4564 | load-balance-loss 0.003234 | learning-rate 1.000000e-06 | gradient-norm 2.1494 | tokens-per-second 82,269 | peak-gpu-memory 57.20 GB
2026-03-31 22:03:45 | INFO | step 00000017/00000025 | step-time 25.391 sec | cross-entropy-loss 2.4275 | load-balance-loss 0.003240 | learning-rate 1.000000e-06 | gradient-norm 1.8759 | tokens-per-second 82,595 | peak-gpu-memory 57.15 GB
2026-03-31 22:04:11 | INFO | step 00000018/00000025 | step-time 25.427 sec | cross-entropy-loss 2.4353 | load-balance-loss 0.003229 | learning-rate 1.000000e-06 | gradient-norm 1.4266 | tokens-per-second 82,476 | peak-gpu-memory 57.19 GB
2026-03-31 22:04:37 | INFO | step 00000019/00000025 | step-time 25.358 sec | cross-entropy-loss 2.4145 | load-balance-loss 0.003224 | learning-rate 1.000000e-06 | gradient-norm 1.3993 | tokens-per-second 82,701 | peak-gpu-memory 57.18 GB
2026-03-31 22:05:02 | INFO | step 00000020/00000025 | step-time 25.394 sec | cross-entropy-loss 2.4180 | load-balance-loss 0.003223 | learning-rate 1.000000e-06 | gradient-norm 1.3634 | tokens-per-second 82,585 | peak-gpu-memory 57.16 GB
2026-03-31 22:05:28 | INFO | step 00000021/00000025 | step-time 25.381 sec | cross-entropy-loss 2.4450 | load-balance-loss 0.003224 | learning-rate 1.000000e-06 | gradient-norm 1.3525 | tokens-per-second 82,626 | peak-gpu-memory 57.18 GB
2026-03-31 22:05:54 | INFO | step 00000022/00000025 | step-time 25.386 sec | cross-entropy-loss 2.4287 | load-balance-loss 0.003230 | learning-rate 1.000000e-06 | gradient-norm 1.2901 | tokens-per-second 82,611 | peak-gpu-memory 57.17 GB
2026-03-31 22:06:19 | INFO | step 00000023/00000025 | step-time 25.318 sec | cross-entropy-loss 2.3759 | load-balance-loss 0.003223 | learning-rate 1.000000e-06 | gradient-norm 1.2784 | tokens-per-second 82,832 | peak-gpu-memory 57.16 GB
2026-03-31 22:06:45 | INFO | step 00000024/00000025 | step-time 25.384 sec | cross-entropy-loss 2.4278 | load-balance-loss 0.003230 | learning-rate 1.000000e-06 | gradient-norm 1.3182 | tokens-per-second 82,618 | peak-gpu-memory 57.14 GB
2026-03-31 22:07:11 | INFO | step 00000025/00000025 | step-time 25.469 sec | cross-entropy-loss 2.4256 | load-balance-loss 0.003226 | learning-rate 1.000000e-06 | gradient-norm 1.2094 | tokens-per-second 82,343 | peak-gpu-memory 57.14 GB

We observe a slight throughput regression (~3.0%) with flash-attn-4 on Hopper (H100): ~82.60K tokens/second on average, versus ~85.11K tokens/second with the TileLang MLA kernel.
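
The averages above can be reproduced from the trainer logs with a small parser (a sketch of my own, not part of the PR):

```python
import re

# Extract the tokens-per-second field from trainer log lines like those
# above and average it across steps.
PATTERN = re.compile(r"tokens-per-second ([\d,]+)")

def avg_tokens_per_second(log_lines):
    vals = [int(m.group(1).replace(",", ""))
            for line in log_lines if (m := PATTERN.search(line))]
    return sum(vals) / len(vals)

# Two sample lines from the TileLang run:
sample = ["... tokens-per-second 85,021 | peak-gpu-memory 57.22 GB",
          "... tokens-per-second 84,993 | peak-gpu-memory 57.19 GB"]
print(avg_tokens_per_second(sample))  # 85007.0
```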


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request removes the custom Multi-Head Latent Attention (MLA) operator implementations across PyTorch, Triton, and TileLang, along with their related benchmarks and tests. The DeepSeek-V2 model is updated to use the standard flash_attn_v4 and ring_attention_func operators, which have been extended to support the asymmetric head dimensions characteristic of MLA. One piece of feedback: input tensors should be made contiguous before calling these attention functions to prevent potential runtime errors or incorrect results.
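
The contiguity concern arises because splitting q/k into (nope, pe) parts along the head dimension yields non-contiguous views. A minimal sketch of the suggested guard (illustrative names, not the PR code):

```python
import torch

# Slicing the last dim of a contiguous tensor produces non-contiguous
# views; attention kernels often require contiguous inputs, so guard
# before the call. `as_contiguous` is a hypothetical helper.
def as_contiguous(*tensors):
    return tuple(t if t.is_contiguous() else t.contiguous() for t in tensors)

q = torch.randn(2, 16, 192)
q_nope, q_pe = q[..., :128], q[..., 128:]   # both are non-contiguous views
q_nope, q_pe = as_contiguous(q_nope, q_pe)
assert q_nope.is_contiguous() and q_pe.is_contiguous()
```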

@haok1402
Contributor Author

On B200, running into issues with torch.compile on flash-attn-4 MLA. Investigation ongoing.

@haok1402 haok1402 force-pushed the 0329-flash-attn-4 branch 3 times, most recently from b4068ff to 43ac8e5 on April 1, 2026 01:51
…SM100 codegen bug that leads to NaN

Made-with: Cursor
@haok1402 haok1402 force-pushed the 0329-flash-attn-4 branch from 43ac8e5 to 1994a83 on April 1, 2026 01:52
@haok1402
Contributor Author

haok1402 commented Apr 1, 2026

> On B200, running into issues with torch.compile on flash-attn-4 MLA. Investigation ongoing.

As a temporary workaround for this Inductor bug on SM100, we move the concatenation of (q_nope, q_pe) and (k_nope, k_pe) into the flash-attention MLA operator so the region stays opaque to the compiler, which resolves the NaN issue in the gradient. Interestingly, this Inductor bug did not occur on SM90.
