
Commit aa8f760

fix - add mla_k_merge kernel
1 parent 11919ac commit aa8f760

8 files changed (+342, -7 lines)


rtp_llm/cpp/kernels/BUILD

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,7 @@ cc_library(
         "//rtp_llm/cpp/cuda:cuda_utils_cu",
         "//rtp_llm/cpp/cuda:launch_utils",
     ],
-    copts = any_cuda_copts(),
+    copts = any_cuda_copts(),
     include_prefix = "src",
     visibility = ["//visibility:public"],
 )
@@ -559,6 +559,7 @@ cc_library(
     ]),
     hdrs = glob([
         "mla_kernels/*.h",
+        "mla_kernels/*.cuh",
     ]),
     deps = any_cuda_deps + [
         "//rtp_llm/cpp/utils:core_utils",

rtp_llm/cpp/kernels/mla_kernels/mla_merge_transpose_kernel.cu

Lines changed: 104 additions & 0 deletions
@@ -195,4 +195,108 @@ INSTANTIATE_MLA_QKV_MERGE(__half);
 #ifdef ENABLE_BF16
 INSTANTIATE_MLA_QKV_MERGE(__nv_bfloat16);
 #endif
+
+// Adapted from sglang/sgl-kernel/csrc/elementwise/concat_mla.cu
+constexpr int NUM_LOCAL_HEADS  = 128;
+constexpr int QK_NOPE_HEAD_DIM = 128;
+constexpr int QK_ROPE_HEAD_DIM = 64;
+constexpr int HEAD_CHUNK_SIZE  = 16;
+constexpr int NUM_HEAD_CHUNKS  = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;
+
+// Fused kernel to concatenate k_nope and k_pe efficiently
+template<typename T>
+__global__ void concat_mla_k_kernel(T* __restrict__ k,
+                                    const T* __restrict__ k_nope,
+                                    const T* __restrict__ k_rope,
+                                    const int     num_tokens,
+                                    const int64_t k_stride_0,
+                                    const int     k_stride_1,
+                                    const int64_t k_nope_stride_0,
+                                    const int     k_nope_stride_1,
+                                    const int64_t k_rope_stride_0) {
+    const int flat_warp_id  = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
+    const int token_id      = flat_warp_id / NUM_HEAD_CHUNKS;
+    const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
+    const int lane_id       = get_lane_id();
+    if (token_id >= num_tokens)
+        return;
+
+    using NopeVec = int2;  // 8 B/thread, 32 threads = 256 B/row
+    using RopeVec = int;   // 4 B/thread, 32 threads = 128 B/row
+    static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(nv_bfloat16), "nope vec mismatch");
+    static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(nv_bfloat16), "rope vec mismatch");
+
+    const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE;
+
+    const int2* __restrict__ nope_src =
+        reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 + head_row0 * k_nope_stride_1) + lane_id;
+
+    int2* __restrict__ nope_dst = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id;
+
+    int* __restrict__ rope_dst =
+        reinterpret_cast<int*>(k + token_id * k_stride_0 + head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) + lane_id;
+
+    const int nope_src_stride_v = (k_nope_stride_1 >> 2);  // int2 covers 4 bf16
+    const int nope_dst_stride_v = (k_stride_1 >> 2);
+    const int rope_dst_stride_v = (k_stride_1 >> 1);  // int covers 2 bf16
+
+    const int*    rope_base = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);
+    const RopeVec rope_val  = ld_na_global_v1(rope_base + lane_id);
+
+    prefetch_L2(nope_src);
+    NopeVec cur = ld_na_global_v2(nope_src);
+
+#pragma unroll
+    for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
+        NopeVec next;
+        if (i + 1 < HEAD_CHUNK_SIZE) {
+            const int2* next_src = nope_src + nope_src_stride_v;
+            prefetch_L2(next_src);
+            next = ld_na_global_v2(next_src);
+        }
+
+        st_na_global_v2(nope_dst, cur);
+        st_na_global_v1(rope_dst, rope_val);
+
+        nope_src += nope_src_stride_v;
+        nope_dst += nope_dst_stride_v;
+        rope_dst += rope_dst_stride_v;
+
+        cur = next;
+    }
+}
+
+template<typename T>
+void invokeMlaKMerge(T* k,
+                     T* k_nope,
+                     T* k_rope,
+                     const int     num_tokens,
+                     const int64_t k_stride_0,
+                     const int     k_stride_1,
+                     const int64_t k_nope_stride_0,
+                     const int     k_nope_stride_1,
+                     const int64_t k_rope_stride_0,
+                     cudaStream_t  stream) {
+    constexpr int num_warps_per_block = 32;
+    const int grid_size  = (num_tokens * NUM_HEAD_CHUNKS + num_warps_per_block - 1) / num_warps_per_block;
+    const int block_size = num_warps_per_block * 32;
+
+    concat_mla_k_kernel<T><<<grid_size, block_size, 0, stream>>>(
+        k, k_nope, k_rope, num_tokens, k_stride_0, k_stride_1, k_nope_stride_0, k_nope_stride_1, k_rope_stride_0);
+}
+
+#define INSTANTIATE_MLA_K_MERGE(T)                            \
+    template void invokeMlaKMerge<T>(T * k_out,               \
+                                     T * k_nope,              \
+                                     T * k_pe,                \
+                                     const int     num_tokens,      \
+                                     const int64_t k_stride_0,      \
+                                     const int     k_stride_1,      \
+                                     const int64_t k_nope_stride_0, \
+                                     const int     k_nope_stride_1, \
+                                     const int64_t k_rope_stride_0, \
+                                     cudaStream_t stream);
+
+#ifdef ENABLE_BF16
+INSTANTIATE_MLA_K_MERGE(__nv_bfloat16);
+#endif
 } // namespace rtp_llm
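
A quick sanity sketch of the launch geometry (not part of the commit): one warp handles one HEAD_CHUNK_SIZE-head slice of one token, so each token needs NUM_HEAD_CHUNKS = 128 / 16 = 8 warps, and the vector types are sized so that a full warp covers exactly one head row. A standalone C++ sketch of the same arithmetic, with num_tokens as an illustrative value:

    #include <cstdio>

    int main() {
        constexpr int kNumHeads     = 128;                     // NUM_LOCAL_HEADS
        constexpr int kHeadChunk    = 16;                      // HEAD_CHUNK_SIZE
        constexpr int kChunksPerTok = kNumHeads / kHeadChunk;  // 8 warps per token
        constexpr int kWarpsPerBlk  = 32;                      // num_warps_per_block
        const int num_tokens  = 1000;                          // example value only
        const int total_warps = num_tokens * kChunksPerTok;                       // 8000
        const int grid_size   = (total_warps + kWarpsPerBlk - 1) / kWarpsPerBlk;  // 250
        const int block_size  = kWarpsPerBlk * 32;                                // 1024 threads
        // Per head row, a warp moves 32 lanes * 8 B (int2) = 256 B = 128 bf16 of
        // nope data plus 32 lanes * 4 B (int) = 128 B = 64 bf16 of rope data,
        // which is exactly what the static_asserts in the kernel pin down.
        std::printf("grid=%d block=%d\n", grid_size, block_size);
        return 0;
    }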

rtp_llm/cpp/kernels/mla_kernels/mla_merge_transpose_kernel.h

Lines changed: 15 additions & 0 deletions
@@ -8,6 +8,8 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 #endif
+#include <cstdint>
+#include "rtp_llm/cpp/kernels/mla_kernels/utils.cuh"
 
 namespace rtp_llm {
 
@@ -36,4 +38,17 @@ void invokeMlaQKVMerge(T* q,
                        int rope_head_dim,
                        int v_head_dim,
                        cudaStream_t stream);
+
+// Fused kernel to concatenate k_nope and k_pe in one operation
+template<typename T>
+void invokeMlaKMerge(T* k_out,
+                     T* k_nope,
+                     T* k_pe,
+                     const int     num_tokens,
+                     const int64_t k_stride_0,
+                     const int     k_stride_1,
+                     const int64_t k_nope_stride_0,
+                     const int     k_nope_stride_1,
+                     const int64_t k_rope_stride_0,
+                     cudaStream_t  stream);
 } // namespace rtp_llm
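
Since only strides (not head counts) are passed in, callers are expected to match the MLA dims the kernel hard-codes. A hedged helper sketch (the function name and layout are illustrative, not from the commit) showing the stride values for contiguous bf16 tensors shaped [T, 128, 192], [T, 128, 128], and [T, 1, 64]:

    #include <cstdint>
    #include <cuda_bf16.h>
    #include <cuda_runtime.h>
    #include "rtp_llm/cpp/kernels/mla_kernels/mla_merge_transpose_kernel.h"

    // Sketch only: pointers are assumed to be valid device allocations of the
    // shapes noted below; the kernel broadcasts the single k_pe row per token
    // across all 128 output heads.
    void mergeContiguous(__nv_bfloat16* k_out,   // [T, 128, 192]
                         __nv_bfloat16* k_nope,  // [T, 128, 128]
                         __nv_bfloat16* k_pe,    // [T, 1, 64]
                         int            T,
                         cudaStream_t   stream) {
        constexpr int kHeads = 128, kNope = 128, kRope = 64;
        rtp_llm::invokeMlaKMerge<__nv_bfloat16>(
            k_out, k_nope, k_pe, T,
            /*k_stride_0=*/int64_t(kHeads) * (kNope + kRope),  // 24576 elems per token
            /*k_stride_1=*/kNope + kRope,                      // 192 elems per head
            /*k_nope_stride_0=*/int64_t(kHeads) * kNope,       // 16384
            /*k_nope_stride_1=*/kNope,                         // 128
            /*k_rope_stride_0=*/kRope,                         // 64: one rope row per token
            stream);
    }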
rtp_llm/cpp/kernels/mla_kernels/utils.cuh

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+// Adapted from https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh
+
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+
+#include <cstdint>
+
+namespace rtp_llm {
+
+__forceinline__ __device__ int get_lane_id() {
+    int lane_id;
+    asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
+    return lane_id;
+}
+
+__device__ __forceinline__ void st_na_global_v1(const int* ptr, int v) {
+    asm volatile("st.global.L1::no_allocate.s32 [%0], %1;" ::"l"(ptr), "r"(v) : "memory");
+}
+
+__device__ __forceinline__ void st_na_global_v2(const int2* ptr, const int2& v) {
+    asm volatile("st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};" ::"l"(ptr), "r"(v.x), "r"(v.y) : "memory");
+}
+
+__device__ __forceinline__ int ld_na_global_v1(const int* ptr) {
+    int r;
+#ifdef USE_L2_HINT
+    asm volatile("ld.global.nc.L1::no_allocate.L2::128B.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
+#else
+    asm volatile("ld.global.nc.L1::no_allocate.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
+#endif
+    return r;
+}
+
+__device__ __forceinline__ int2 ld_na_global_v2(const int2* ptr) {
+    int2 r;
+#ifdef USE_L2_HINT
+    asm volatile("ld.global.nc.L1::no_allocate.L2::128B.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
+#else
+    asm volatile("ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
+#endif
+    return r;
+}
+
+__device__ __forceinline__ void prefetch_L2(const void* p) {
+#if defined(ENABLE_L2_PREFETCH)
+    asm volatile("prefetch.global.L2 [%0];" ::"l"(p));
+#endif
+}
+
+} // namespace rtp_llm
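
A note on these primitives: the ld.global.nc / L1::no_allocate forms stream data through without allocating L1 lines (the merge touches each value once, so caching it would only evict useful data), and the optional L2::128B and prefetch.global.L2 hints warm L2 ahead of the loads. A minimal device-side sketch (an assumption here: a GPU generation recent enough to accept these PTX eviction hints) of how the load/store pair composes into a streaming copy:

    #include "rtp_llm/cpp/kernels/mla_kernels/utils.cuh"

    // Sketch only: each thread moves one int2 (8 B), bypassing L1 allocation
    // on both the load and the store.
    __global__ void stream_copy_kernel(int2* __restrict__ dst, const int2* __restrict__ src, int n) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n)
            return;
        rtp_llm::prefetch_L2(src + i);  // compiled out unless ENABLE_L2_PREFETCH is defined
        const int2 v = rtp_llm::ld_na_global_v2(src + i);
        rtp_llm::st_na_global_v2(dst + i, v);
    }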
rtp_llm/models_py/bindings/cuda/MlaKMergeOp.cc

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+#include "rtp_llm/models_py/bindings/cuda/MlaKMergeOp.h"
+#include "rtp_llm/cpp/kernels/mla_kernels/mla_merge_transpose_kernel.h"
+#include "rtp_llm/cpp/core/torch_utils/BufferTorchUtils.h"
+#include "rtp_llm/models_py/bindings/common/Torch_ext.h"
+#include <cuda_runtime.h>
+
+namespace rtp_llm {
+
+void MlaKMerge(torch::Tensor k_out, torch::Tensor k_nope, torch::Tensor k_pe) {
+    TORCH_CHECK(k_out.is_cuda(), "k_out must be on CUDA");
+    TORCH_CHECK(k_nope.is_cuda(), "k_nope must be on CUDA");
+    TORCH_CHECK(k_pe.is_cuda(), "k_pe must be on CUDA");
+
+    TORCH_CHECK(k_out.dim() == 3, "k_out must be 3D: [token_num, head_num, nope_head_dim + rope_head_dim]");
+    TORCH_CHECK(k_nope.dim() == 3, "k_nope must be 3D: [token_num, head_num, nope_head_dim]");
+    TORCH_CHECK(k_pe.dim() == 3, "k_pe must be 3D: [token_num, 1, rope_head_dim]");
+
+    StreamType stream = GET_CURRENT_STREAM();
+
+    const int     num_tokens      = k_out.size(0);
+    const int64_t k_stride_0      = k_out.stride(0);
+    const int     k_stride_1      = k_out.stride(1);
+    const int64_t k_nope_stride_0 = k_nope.stride(0);
+    const int     k_nope_stride_1 = k_nope.stride(1);
+    const int64_t k_rope_stride_0 = k_pe.stride(0);
+
+    // Dispatch based on dtype
+    if (k_out.dtype() == torch::kBFloat16) {
+        invokeMlaKMerge<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(k_out.data_ptr()),
+                                       reinterpret_cast<__nv_bfloat16*>(k_nope.data_ptr()),
+                                       reinterpret_cast<__nv_bfloat16*>(k_pe.data_ptr()),
+                                       num_tokens,
+                                       k_stride_0,
+                                       k_stride_1,
+                                       k_nope_stride_0,
+                                       k_nope_stride_1,
+                                       k_rope_stride_0,
+                                       stream);
+    } else {
+        TORCH_CHECK(false, "Unsupported dtype: ", k_out.dtype());
+    }
+}
+
+} // namespace rtp_llm
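
For shape intuition, a hedged libtorch (C++) sketch of driving this op; the tensor sizes follow the TORCH_CHECKs above and the 128-head, 128+64-dim layout the kernel hard-codes (the function name here is illustrative, not from the commit):

    #include <torch/torch.h>
    #include "rtp_llm/models_py/bindings/cuda/MlaKMergeOp.h"

    // Sketch only: build bf16 CUDA tensors and merge. Afterwards,
    //   k_out[:, h, :128] == k_nope[:, h, :]  and  k_out[:, h, 128:] == k_pe[:, 0, :]
    // for every head h, since the single rope row is broadcast across heads.
    void runMlaKMergeExample(int64_t T) {
        auto opts   = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
        auto k_out  = torch::empty({T, 128, 128 + 64}, opts);
        auto k_nope = torch::randn({T, 128, 128}, opts);
        auto k_pe   = torch::randn({T, 1, 64}, opts);
        rtp_llm::MlaKMerge(k_out, k_nope, k_pe);
    }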
rtp_llm/models_py/bindings/cuda/MlaKMergeOp.h

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <torch/extension.h>
+
+namespace rtp_llm {
+
+void MlaKMerge(torch::Tensor k_out, torch::Tensor k_nope, torch::Tensor k_pe);
+
+} // namespace rtp_llm

rtp_llm/models_py/bindings/cuda/RegisterBaseBindings.hpp

Lines changed: 8 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "3rdparty/flashinfer/flashinfer.h"
 #include "rtp_llm/models_py/bindings/cuda/TrtFp8QuantOp.h"
 #include "rtp_llm/models_py/bindings/cuda/ReuseKVCacheOp.h"
+#include "rtp_llm/models_py/bindings/cuda/MlaKMergeOp.h"
 
 using namespace rtp_llm;
 
@@ -145,6 +146,13 @@ void registerBasicCudaOps(py::module& rtp_ops_m) {
                   py::arg("batch_reuse_info_vec"),
                   py::arg("qo_indptr"),
                   py::arg("tokens_per_block"));
+
+    rtp_ops_m.def("mla_k_merge",
+                  &rtp_llm::MlaKMerge,
+                  "Fused kernel to merge k_nope and k_pe efficiently",
+                  py::arg("k_out"),
+                  py::arg("k_nope"),
+                  py::arg("k_pe"));
 }
 
 void registerBaseCudaBindings(py::module& rtp_ops_m) {
