Replace MLA Triton/TileLang kernels with flash-attn-4 #3
haok1402 wants to merge 2 commits into mlc-ai:main
Conversation
Code Review
This pull request removes the custom Multi-Head Latent Attention (MLA) operator implementations across PyTorch, Triton, and TileLang, along with their benchmarks and tests. The DeepSeek-V2 model is updated to use the standard flash_attn_v4 and ring_attention_func operators, which have been extended to support the asymmetric head dimensions characteristic of MLA. The review also flagged the need to ensure input tensors are contiguous before calling these attention functions, to avoid runtime errors or incorrect results.
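The contiguity point deserves a concrete illustration: flash-attention-style kernels generally assume contiguous (or at least last-dim-contiguous) inputs, and views produced by `transpose`/`permute` silently violate that. A minimal sketch of the suggested guard, with `ensure_contiguous` as a hypothetical helper (not an identifier from this PR):

```python
import torch

def ensure_contiguous(*tensors):
    """Return contiguous versions of the given tensors (no-op copies are avoided)."""
    return tuple(t if t.is_contiguous() else t.contiguous() for t in tensors)

# Example: a transpose produces a non-contiguous view, which an attention
# kernel expecting dense strides could mis-read or reject.
q = torch.randn(2, 16, 128, 192).transpose(1, 2)  # (B, S, H, Dq) view
assert not q.is_contiguous()
(q,) = ensure_contiguous(q)
assert q.is_contiguous()
```

In the model code, such a guard would run on `q`, `k`, and `v` immediately before the `flash_attn_v4` / `ring_attention_func` call.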
On B200, running into issues with
Force-pushed b4068ff to 43ac8e5
…SM100 codegen bug that leads to NaN
Made-with: Cursor
Force-pushed 43ac8e5 to 1994a83
As a temporary workaround for this Inductor bug on SM100, we move the concat of (
Kernel Performance on NVIDIA H100
MLA: flash-attn-4 vs TileLang (h=16, dq=192, dv=128)
(forward and backward benchmark charts)
Kernel Performance on NVIDIA B200
MLA: flash-attn-4 vs Triton MLA (h=16, dq=192, dv=128)
(forward and backward benchmark charts)
End-to-End Performance on NVIDIA H100
Model: deepseek-v2-lite; Device: 1x8-h100; Mesh: 2pp-2dp-1cp-2ep; Sequence Length: 2048; Dtype: bf16.
Before the change, using MLA with TileLang,
After the change, using flash-attn-4,
We observe a slight throughput regression (~3.0%) with flash-attn-4 on Hopper (H100), averaging ~82.60K tokens/second compared to ~85.11K tokens/second with MLA TileLang.
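The quoted regression can be sanity-checked from the two throughput figures directly; a one-liner using the numbers above:

```python
before = 85.11e3  # tokens/s with MLA TileLang
after = 82.60e3   # tokens/s with flash-attn-4
regression = (before - after) / before
print(f"{regression:.1%}")  # → 2.9%
```

which rounds to the ~3.0% stated above.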