
Commit 11fb997

sdpa kvcache (#6405)
1 parent 69652f4 commit 11fb997

8 files changed, +369 -58 lines changed


docs/developer-guide/kvcache.md

Lines changed: 38 additions & 23 deletions
@@ -1,6 +1,6 @@
 # high-performance transformer inference with mha kv cache in ncnn

-This document details the implementation and usage of the key-value (kv) cache for the `MultiHeadAttention` layer in ncnn. This feature significantly accelerates autoregressive inference for Transformer-based models, such as large language models and other encoder-decoder architectures.
+This document details the implementation and usage of the key-value (kv) cache for the `MultiHeadAttention` and `SDPA` layers in ncnn. This feature significantly accelerates autoregressive inference for Transformer-based models, such as large language models and other encoder-decoder architectures.

 ## 1. what is kv cache?

@@ -20,9 +20,9 @@ Without optimization, the model must recompute the k and v matrices for all prec
 - **reduced computation:** It eliminates redundant calculations, saving significant computational resources and energy.
 - **enables real-time applications:** The performance gain makes it feasible to deploy large Transformer models for interactive and real-time tasks.

-## 2. ncnn mha kv cache implementation
+## 2. ncnn kv cache implementation

-ncnn introduces kv cache support directly into its `MultiHeadAttention` layer. The implementation is designed to be efficient and flexible, handling both the dynamic cache of self-attention and the static k/v of cross-attention found in encoder-decoder architectures.
+ncnn introduces kv cache support directly into its `MultiHeadAttention` and `SDPA` layers. The implementation is designed to be efficient and flexible, handling both the dynamic cache of self-attention and the static k/v of cross-attention found in encoder-decoder architectures.

 ### self-attention vs. cross-attention cache logic

@@ -31,34 +31,49 @@ The caching strategy is fundamentally different for self-attention and cross-att
 #### self-attention (dynamic cache)
 - **purpose:** Allows the decoder to attend to previously generated tokens in its own sequence (e.g., the text being generated).
 - **cache Logic:** The cache is **dynamic** and grows with each generated token. In step `t`, the k and v for token `t` are computed and appended to the cache from step `t-1`.
-- **ncnn implementation:** The `MultiHeadAttention` layer for self-attention is modified to accept two additional inputs (`cache_k_in`, `cache_v_in`) and produce two corresponding outputs (`cache_k_out`, `cache_v_out`). The `7=1` parameter enables this dynamic caching behavior inside the layer.
+- **ncnn implementation:** The `MultiHeadAttention` and `SDPA` layers for self-attention are modified to accept two additional inputs (`cache_k_in`, `cache_v_in`) and produce two corresponding outputs (`cache_k_out`, `cache_v_out`). The `7=1` parameter enables this dynamic caching behavior inside the layer.

 #### cross-attention (static k/v)
 - **purpose:** Allows the decoder to attend to the output of the encoder (e.g., attending to audio features in speech recognition, or an input sentence in translation).
 - **cache Logic:** The k and v matrices are derived from the encoder's output, which is computed only **once** per input sequence. Therefore, the k and v for cross-attention are **static** and do not change during the decoding process. They are "cached" in the sense that they are pre-computed and reused in every decoding step.
-- **ncnn implementation:** The `MultiHeadAttention` layer for cross-attention is also configured with `7=1` and cache I/O blobs. However, the implementation correctly identifies cross-attention (where the query blob is different from the key/value blobs) and reuses the `cache_k_in` and `cache_v_in` directly, without performing concatenation. This allows the static encoder k/v to be passed efficiently through the network.
+- **ncnn implementation:** The `MultiHeadAttention` and `SDPA` layers for cross-attention are also configured with `7=1` and cache I/O blobs. However, the implementation correctly identifies cross-attention (where the query blob is different from the key/value blobs) and reuses the `cache_k_in` and `cache_v_in` directly, without performing concatenation. This allows the static encoder k/v to be passed efficiently through the network.

-## 3. ncnn mha kv cache memory layout
+## 3. ncnn kv cache memory layout

-The memory layout of the kv cache is a critical design choice for performance. ncnn uses a **transposed layout** for the cache blobs. The primary reason for this is to **ensure that data for each attention head is contiguous in memory, which significantly boosts gemm performance.**
+The memory layout of the kv cache is a critical design choice for performance. ncnn uses different layouts for `MultiHeadAttention` and `SDPA` to optimize for their respective calculation patterns.
+
+### `MultiHeadAttention` cache layout (Transposed)
+
+The `MultiHeadAttention` layer uses a **transposed layout** for its cache blobs. The primary reason for this is to **ensure that data for each attention head is contiguous in memory, which significantly boosts gemm performance.**

 * **input blobs (q, k, v):** These typically have a shape where height represents the sequence length.
   * `ncnn::Mat` dimensions: `(w = embed_dim, h = seq_len)`

-* **cache blobs (e.g., `k_affine`, `v_affine`):** These are stored in a **transposed** format.
+* **cache blobs (`k_cache`, `v_cache`):** These are stored in a **transposed** format.
   * `ncnn::Mat` dimensions: `(w = seq_len, h = embed_dim)`

 **the rationale:**

-1. **slicing by Head:** During the attention calculation, the code slices the `k_affine` and `v_affine` matrices along their height to isolate the data for each head (e.g., using `row_range(head_index * embed_dim_per_head, embed_dim_per_head)`).
+1. **slicing by Head:** During the attention calculation, the code slices the `k_cache` and `v_cache` matrices along their height to isolate the data for each head (e.g., using `row_range(head_index * embed_dim_per_head, embed_dim_per_head)`).
 2. **memory contiguity:** Because `ncnn::Mat` uses a row-major memory layout, this slicing operation on the transposed cache blob results in a sub-matrix where all the data for a single head is perfectly contiguous.
 3. **gemm efficiency:** Subsequent matrix multiplication operations (`q * k^T` and `Attention * v`) can then operate on these contiguous memory blocks. This maximizes CPU cache locality and the effectiveness of simd instructions, leading to a substantial increase in computational speed.

 If a non-transposed layout were used, the data for each head would be strided in memory, causing frequent cache misses and dramatically slowing down the performance-critical gemm calculations. Therefore, this transposed layout is a deliberate and crucial optimization for computation.

+### `SDPA` cache layout (Standard)
+
+The `SDPA` layer uses the **standard ncnn Mat layout**, where the sequence length is represented by the height.
+
+* **input blobs (q, k, v):** `(w = embed_dim, h = seq_len, c = num_heads)`
+* **cache blobs (`k_cache`, `v_cache`):** `(w = embed_dim, h = seq_len, c = num_heads)`
+
+**the rationale:**
+
+The `SDPA` layer's internal implementation directly concatenates the cache blobs (`past_k`, `past_v`) with the current ones (`cur_k`, `cur_v`) along the height dimension (`seq_len`). This simpler approach avoids the need for a transposed layout while still being highly efficient, as the concatenation logic is handled inside the optimized C++ implementation.
+
 ## 4. converting models to support kv cache

-To enable kv cache, you must modify the model's `.param` file to add the necessary cache inputs and outputs to all `MultiHeadAttention` layers in the decoder.
+To enable kv cache, you must modify the model's `.param` file to add the necessary cache inputs and outputs to all `MultiHeadAttention` and `SDPA` layers in the decoder.

 ### step 1: export a sequence-length-1 model
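Editor's note: as a quick reference for the layout description in the hunk above, here is a minimal C++ sketch of the two cache layouts. The standalone helper function and the concrete sizes are illustrative assumptions, not code from this commit; only the Mat shapes and the `row_range` slicing follow the documentation.

// illustrative only: shapes follow the kv cache documentation above
#include "mat.h" // ncnn::Mat, as in the ncnn source tree

void cache_layout_sketch()
{
    const int embed_dim = 512;
    const int num_heads = 8;
    const int embed_dim_per_head = embed_dim / num_heads;
    const int seq_len = 17; // tokens cached so far

    // MultiHeadAttention cache: transposed layout, (w = seq_len, h = embed_dim)
    ncnn::Mat mha_k_cache(seq_len, embed_dim);
    // slicing by head yields a contiguous block because ncnn::Mat is row-major
    ncnn::Mat head0 = mha_k_cache.row_range(0 * embed_dim_per_head, embed_dim_per_head);
    (void)head0;

    // SDPA cache: standard layout, (w = embed_dim, h = seq_len, c = num_heads)
    ncnn::Mat sdpa_k_cache(embed_dim, seq_len, num_heads);
    // appending a token adds one row per channel; the SDPA layer performs this
    // concatenation internally when kv_cache (param 7=1) is enabled
}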

@@ -68,9 +83,9 @@ First, export your model from its original framework (e.g., PyTorch) using a seq
 After exporting, a script is needed to edit the generated `.ncnn.param` file to make it cache-aware.

-#### A. Adding kv cache to All MultiHeadAttention Layers
+#### A. Adding kv cache to All MultiHeadAttention and SDPA Layers

-You must add cache inputs/outputs to **every** `MultiHeadAttention` layer in the decoder.
+You must add cache inputs/outputs to **every** `MultiHeadAttention` / `SDPA` layer in the decoder.

 - **change `input_count` and `output_count`:** Increase both by 2.
 - **add blob names:** Append new, unique blob names for `cache_k_in`, `cache_v_in`, `cache_k_out`, and `cache_v_out`.
@@ -81,7 +96,7 @@ Here is a robust Python function that automates this process:
 def add_kv_cache_to_ncnn_param(filename):
     """
     Modifies an ncnn.param file to add a kv cache mechanism to all
-    MultiHeadAttention layers and overwrites the original file.
+    MultiHeadAttention and SDPA layers and overwrites the original file.
     This handles both self-attention and cross-attention layers.
     """
     import os
@@ -98,15 +113,15 @@ def add_kv_cache_to_ncnn_param(filename):
     original_layer_count = int(header_parts[0])
     original_blob_count = int(header_parts[1])

-    mha_indices = [i for i, line in enumerate(lines) if line.strip().startswith("MultiHeadAttention")]
-    mha_count = len(mha_indices)
+    attention_indices = [i for i, line in enumerate(lines) if line.strip().startswith("MultiHeadAttention") or line.strip().startswith("SDPA")]
+    attention_count = len(attention_indices)

-    if mha_count == 0:
-        print("No 'MultiHeadAttention' layers found. The file will not be modified.")
+    if attention_count == 0:
+        print("No 'MultiHeadAttention' or 'SDPA' layers found. The file will not be modified.")
         return

-    # --- modify MultiHeadAttention layers ---
-    for i, line_index in enumerate(mha_indices):
+    # --- modify MultiHeadAttention and SDPA layers ---
+    for i, line_index in enumerate(attention_indices):
         parts = lines[line_index].strip().split()
         layer_type, layer_name, input_count_str, output_count_str = parts[:4]
         input_count, output_count = int(input_count_str), int(output_count_str)
@@ -132,15 +147,15 @@ def add_kv_cache_to_ncnn_param(filename):
     new_layer_count = original_layer_count + 1
     # each mha needs 2 new *input* blobs and produces 2 new *output* blobs.
     # the total number of unique blobs increases by 4 for each mha.
-    new_blob_count = original_blob_count + (mha_count * 4)
+    new_blob_count = original_blob_count + (attention_count * 4)
     lines[header_line_index] = f"{new_layer_count} {new_blob_count}\n"

     # find where to insert the new input layer (after existing ones)
     insert_pos = header_line_index + 1
     while insert_pos < len(lines) and lines[insert_pos].strip().startswith("Input"):
         insert_pos += 1

-    cache_blob_names = [name for i in range(mha_count) for name in (f"cache_k_in_{i}", f"cache_v_in_{i}")]
+    cache_blob_names = [name for i in range(attention_count) for name in (f"cache_k_in_{i}", f"cache_v_in_{i}")]
     input_layer_line = (
         f"{'Input':<24} {'kv_cache_in':<24} 0 {len(cache_blob_names)} "
         f"{' '.join(cache_blob_names)}\n"
@@ -150,7 +165,7 @@ def add_kv_cache_to_ncnn_param(filename):
     with open(filename, 'w', encoding='utf-8') as f:
         f.writelines(lines)

-    print(f"Successfully added kv cache to {mha_count} MultiHeadAttention layers.")
+    print(f"Successfully added kv cache to {attention_count} MultiHeadAttention / SDPA layers.")

 # usage:
 # add_kv_cache_to_ncnn_param("your_model_decoder.ncnn.param")
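Editor's note: the edit this script performs on a single attention layer line looks roughly like the hypothetical before/after below. The layer name, blob names, and the abbreviated trailing parameters (shown as ...) are made up for illustration; only the counts increased by 2, the appended cache blob names, and the added `7=1` follow the documentation. The same edit applies to `SDPA` layer lines.

before:
    MultiHeadAttention  attn_0  3 1  q0 k0 v0  attn_out0  ...
after:
    MultiHeadAttention  attn_0  5 3  q0 k0 v0 cache_k_in_0 cache_v_in_0  attn_out0 cache_k_out_0 cache_v_out_0  ... 7=1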
@@ -206,7 +221,7 @@ void find_mha_kvcache_blobs(const ncnn::Net& net, kvcache_info& info)
     for (const ncnn::Layer* layer : net.layers())
     {
         // cache-enabled mha layer has 3 outputs (out, cache_k_out, cache_v_out) instead of 1
-        if (layer->typeindex == ncnn::LayerType::MultiHeadAttention && layer->tops.size() == 3)
+        if ((layer->typeindex == ncnn::LayerType::MultiHeadAttention || layer->typeindex == ncnn::LayerType::SDPA) && layer->tops.size() == 3)
         {
             // the script adds cache_k and cache_v as the last two inputs/outputs
             int input_count = layer->bottoms.size();
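Editor's note: the hunk above is the blob-discovery helper from the same document. Below is a minimal sketch of how the discovered cache blobs might be carried across decode steps. The `kvcache_info` fields, the blob names "in0"/"out0", and the loop structure are assumptions for illustration, not the documented API.

#include "net.h" // ncnn::Net, ncnn::Extractor, ncnn::Mat
#include <vector>

// assumed shape of the struct filled by find_mha_kvcache_blobs() above
struct kvcache_info
{
    std::vector<int> cache_input_blobs;  // cache_k_in / cache_v_in blob indices
    std::vector<int> cache_output_blobs; // cache_k_out / cache_v_out blob indices
};

// hypothetical single-token decode step: feed last step's caches, collect the new ones
void decode_step(const ncnn::Net& net, const kvcache_info& info,
                 const ncnn::Mat& token_embedding, std::vector<ncnn::Mat>& caches)
{
    // caches must be pre-sized to info.cache_input_blobs.size(); empty Mats on the first step
    ncnn::Extractor ex = net.create_extractor();
    ex.input("in0", token_embedding); // "in0" is a placeholder input blob name

    for (size_t i = 0; i < info.cache_input_blobs.size(); i++)
        ex.input(info.cache_input_blobs[i], caches[i]);

    ncnn::Mat out;
    ex.extract("out0", out); // "out0" is a placeholder output blob name

    for (size_t i = 0; i < info.cache_output_blobs.size(); i++)
        ex.extract(info.cache_output_blobs[i], caches[i]); // carry over to the next step
}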

docs/developer-guide/operators.md

Lines changed: 1 addition & 0 deletions
@@ -1811,6 +1811,7 @@ for each num_head part
 | --------- | ------------- | ----- | --------- | ----------------- |
 | 5 | attn_mask | int | 0 | |
 | 6 | scale | float | 0.f | auto = 1.f / sqrt(embed_dim) |
+| 7 | kv_cache | int | 0 | |
 | 18 | int8_scale_term | int | 0 | |

 # SELU

src/layer/sdpa.cpp

Lines changed: 123 additions & 20 deletions
@@ -17,6 +17,7 @@ int SDPA::load_param(const ParamDict& pd)
 {
     attn_mask = pd.get(5, 0);
     scale = pd.get(6, 0.f);
+    kv_cache = pd.get(7, 0);
     int8_scale_term = pd.get(18, 0);

     return 0;
@@ -33,20 +34,24 @@ int SDPA::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
 #endif

     const Mat& query = bottom_blobs[0];
-    const Mat& key = bottom_blobs[1];
-    const Mat& value = bottom_blobs[2];
-    const Mat& attn_mask_blob = bottom_blobs.size() == 4 ? bottom_blobs[3] : Mat();
+    const Mat& cur_key = bottom_blobs[1];
+    const Mat& cur_value = bottom_blobs[2];
+    const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
+    const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
+    const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat();

     const int embed_dim = query.w;
     const int src_seqlen = query.h;
     const int num_heads = query.c;
-    const int dst_seqlen = key.h;
-    const int num_group = key.c;
-    const int out_embed_dim = value.w;
-
-    // assert key.w == embed_dim
-    // assert key.h == value.h == dst_seqlen
-    // assert value.c == num_group
+    const int cur_seqlen = cur_key.h;
+    const int num_group = cur_key.c;
+    const int out_embed_dim = cur_value.w;
+    const int past_seqlen = kv_cache ? past_key.h : 0;
+    const int dst_seqlen = past_seqlen + cur_seqlen;
+
+    // assert cur_key.w == embed_dim
+    // assert cur_key.h == cur_value.h == cur_seqlen
+    // assert cur_value.c == num_group
     // assert num_heads % num_group == 0

     const float _scale = scale == 0.f ? 1.f / sqrt(embed_dim) : scale;
@@ -61,6 +66,46 @@ int SDPA::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
     if (qk_cross.empty())
         return -100;

+    Mat key = cur_key;
+    if (past_seqlen > 0)
+    {
+        key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
+        if (key.empty())
+            return -100;
+
+        // concat
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = 0; q < num_group; q++)
+        {
+            const Mat past_key_head = past_key.channel(q);
+            const Mat cur_key_head = cur_key.channel(q);
+            Mat key_head = key.channel(q);
+
+            memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float));
+            memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float));
+        }
+    }
+
+    Mat value = cur_value;
+    if (past_seqlen > 0)
+    {
+        value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
+        if (value.empty())
+            return -100;
+
+        // concat
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = 0; q < num_group; q++)
+        {
+            const Mat past_value_head = past_value.channel(q);
+            const Mat cur_value_head = cur_value.channel(q);
+            Mat value_head = value.channel(q);
+
+            memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float));
+            memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float));
+        }
+    }
+
     #pragma omp parallel for num_threads(opt.num_threads)
     for (int q = 0; q < num_heads; q++)
     {
@@ -153,6 +198,13 @@ int SDPA::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
         }
     }

+    if (kv_cache)
+    {
+        // assert top_blobs.size() == 3
+        top_blobs[1] = key;
+        top_blobs[2] = value;
+    }
+
     return 0;
 }

@@ -223,20 +275,24 @@ static void dynamic_quantize_2d_per_h(const Mat& blob, Mat& blob_int8, Mat& scal
 int SDPA::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
 {
     const Mat& query = bottom_blobs[0];
-    const Mat& key = bottom_blobs[1];
-    const Mat& value = bottom_blobs[2];
-    const Mat& attn_mask_blob = bottom_blobs.size() == 4 ? bottom_blobs[3] : Mat();
+    const Mat& cur_key = bottom_blobs[1];
+    const Mat& cur_value = bottom_blobs[2];
+    const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
+    const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
+    const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat();

     const int embed_dim = query.w;
     const int src_seqlen = query.h;
     const int num_heads = query.c;
-    const int dst_seqlen = key.h;
-    const int num_group = key.c;
-    const int out_embed_dim = value.w;
-
-    // assert key.w == embed_dim
-    // assert key.h == value.h == dst_seqlen
-    // assert value.c == num_group
+    const int cur_seqlen = cur_key.h;
+    const int num_group = cur_key.c;
+    const int out_embed_dim = cur_value.w;
+    const int past_seqlen = kv_cache ? past_key.h : 0;
+    const int dst_seqlen = past_seqlen + cur_seqlen;
+
+    // assert cur_key.w == embed_dim
+    // assert cur_key.h == cur_value.h == cur_seqlen
+    // assert cur_value.c == num_group
     // assert num_heads % num_group == 0

     const float _scale = scale == 0.f ? 1.f / sqrt(embed_dim) : scale;
@@ -271,6 +327,46 @@ int SDPA::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
     if (query_or_qk_cross_int8_scales.empty())
         return -100;

+    Mat key = cur_key;
+    if (past_seqlen > 0)
+    {
+        key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
+        if (key.empty())
+            return -100;
+
+        // concat
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = 0; q < num_group; q++)
+        {
+            const Mat past_key_head = past_key.channel(q);
+            const Mat cur_key_head = cur_key.channel(q);
+            Mat key_head = key.channel(q);
+
+            memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float));
+            memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float));
+        }
+    }
+
+    Mat value = cur_value;
+    if (past_seqlen > 0)
+    {
+        value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
+        if (value.empty())
+            return -100;
+
+        // concat
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = 0; q < num_group; q++)
+        {
+            const Mat past_value_head = past_value.channel(q);
+            const Mat cur_value_head = cur_value.channel(q);
+            Mat value_head = value.channel(q);
+
+            memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float));
+            memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float));
+        }
+    }
+
     #pragma omp parallel for num_threads(opt.num_threads)
     for (int q = 0; q < num_heads; q++)
     {
@@ -389,6 +485,13 @@ int SDPA::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
         }
     }

+    if (kv_cache)
+    {
+        // assert top_blobs.size() == 3
+        top_blobs[1] = key;
+        top_blobs[2] = value;
+    }
+
     return 0;
 }
 #endif // NCNN_INT8
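Editor's note: from the indexing in forward() above, a cache-enabled SDPA layer takes q, k, v, an optional attn_mask, then past_k and past_v as inputs, and produces out, k_cache, v_cache as outputs. Below is a minimal sketch of driving the layer directly with that ordering; the concrete sizes and the test-style setup are assumptions, not part of this commit.

#include "layer.h" // ncnn::create_layer, ncnn::ParamDict, ncnn::Option, ncnn::Mat
#include <vector>

// hypothetical direct invocation of a cache-enabled SDPA layer (no attn_mask)
int run_sdpa_with_kvcache()
{
    ncnn::Layer* sdpa = ncnn::create_layer("SDPA");

    ncnn::ParamDict pd;
    pd.set(5, 0); // attn_mask = 0
    pd.set(7, 1); // kv_cache = 1
    sdpa->load_param(pd);

    ncnn::Option opt;
    sdpa->create_pipeline(opt);

    const int embed_dim = 64, num_heads = 4, past_seqlen = 16;
    ncnn::Mat q(embed_dim, 1, num_heads); // current token only
    ncnn::Mat cur_k(embed_dim, 1, num_heads);
    ncnn::Mat cur_v(embed_dim, 1, num_heads);
    ncnn::Mat past_k(embed_dim, past_seqlen, num_heads);
    ncnn::Mat past_v(embed_dim, past_seqlen, num_heads);
    q.fill(0.f);
    cur_k.fill(0.f);
    cur_v.fill(0.f);
    past_k.fill(0.f);
    past_v.fill(0.f);

    // input order without attn_mask: q, cur_k, cur_v, past_k, past_v
    std::vector<ncnn::Mat> bottoms(5);
    bottoms[0] = q;
    bottoms[1] = cur_k;
    bottoms[2] = cur_v;
    bottoms[3] = past_k;
    bottoms[4] = past_v;
    std::vector<ncnn::Mat> tops(3); // out, k_cache, v_cache

    int ret = sdpa->forward(bottoms, tops, opt);

    sdpa->destroy_pipeline(opt);
    delete sdpa;
    return ret; // tops[1]/tops[2] now hold caches with h = past_seqlen + 1
}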

src/layer/sdpa.h

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ class SDPA : public Layer
 public:
     int attn_mask;
     float scale;
+    int kv_cache;

     int int8_scale_term;
 };
