Add turbomind_rms_norm to accelerate QK norm in Qwen3 models #13189
Conversation
Looks good. Have you tested it on H100/B200? If not, I can help you add some H100/B200 tests.
Thanks for reviewing! I don't have access to H100/B200, so I couldn't test on those.
H100 results, looks pretty good.
Motivation
For RMSNorm with head_dim <= 128, the QK-norm implementation in lmdeploy outperforms the current flashinfer RMSNorm implementation.
In a benchmark on H20 (head_dim = 128, head_num = 48, token_num = 4096), latency is reduced from 269 µs to 69 µs.
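For context, QK norm applies RMSNorm over each attention head's query/key vector, so the normalized axis is head_dim rather than the model's hidden size. A minimal NumPy reference of that operation (not the kernel itself; the eps value is an assumption), using the benchmark's shape:

```python
import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Reference RMSNorm over the last axis (head_dim for QK norm)."""
    # Mean of squares over head_dim in float64 for stability,
    # then rescale by the learned per-channel weight.
    variance = np.mean(x.astype(np.float64) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps) * weight).astype(x.dtype)

# Shape matching the H20 benchmark: (token_num, head_num, head_dim)
q = np.random.randn(4096, 48, 128).astype(np.float32)
w = np.ones(128, dtype=np.float32)
out = rms_norm(q, w)
```

Because the reduction axis is only head_dim (<= 128) elements wide, a kernel tuned for short rows can beat one tuned for full hidden-size rows, which is the regime this PR targets.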
Modifications
- Added turbomind_rms_norm.
- Use turbomind_rms_norm when hidden_size <= 128.
- Added a benchmark for turbomind_rms_norm.
Accuracy Tests
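The hidden_size-based dispatch described in the modifications can be sketched as follows; the function and backend names here are illustrative, not the actual sgl-kernel API:

```python
# Threshold from this PR: the turbomind kernel handles hidden_size <= 128.
TURBOMIND_MAX_HIDDEN = 128

def dispatch_rms_norm(hidden_size: int) -> str:
    """Pick an RMSNorm backend by hidden_size (names are hypothetical)."""
    if hidden_size <= TURBOMIND_MAX_HIDDEN:
        return "turbomind_rms_norm"   # new fast path, e.g. QK norm heads
    return "flashinfer_rms_norm"      # existing path for wide rows

print(dispatch_rms_norm(128))   # QK norm with head_dim = 128 -> turbomind
print(dispatch_rms_norm(4096))  # full hidden states -> flashinfer
```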
Unit tests
Tests in sgl-kernel/tests/test_norm.py pass.
On H20, model Qwen/Qwen3-8B-FP8:
python3 -m sglang.test.few_shot_gsm8k --num-questions 1000
before:
after:
Benchmarking and Profiling
kernel benchmark (on H20):
python /sgl-workspace/sglang/sgl-kernel/benchmark/bench_turbomind_rmsnorm.py
results:
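Kernel microbenchmarks of this kind typically warm up first, then report the mean wall time over many iterations. A minimal host-side timing sketch under those assumptions (the benchmark script's actual methodology may differ, e.g. CUDA events and synchronization for GPU kernels):

```python
import time
import numpy as np

def bench(fn, *args, warmup: int = 3, iters: int = 10) -> float:
    """Return mean latency of fn(*args) in microseconds."""
    for _ in range(warmup):      # warm caches / JIT before timing
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters * 1e6

def rms_norm_np(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

# token_num x (head_num * head_dim) from the H20 benchmark shape
x = np.random.randn(4096, 48 * 128).astype(np.float32)
latency_us = bench(rms_norm_np, x)
```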
e2e benchmark (on H20, model: Qwen/Qwen3-0.6B-FP8):
results (baseline, rmsnorm):
results (turbomind_rms_norm):
Checklist