Commit de17a53

update
1 parent 36d5c8f commit de17a53

File tree

1 file changed: +38 -23 lines changed

docs/developer-guide/kvcache.md

Lines changed: 38 additions & 23 deletions
@@ -1,6 +1,6 @@
 # high-performance transformer inference with mha kv cache in ncnn
 
-This document details the implementation and usage of the key-value (kv) cache for the `MultiHeadAttention` layer in ncnn. This feature significantly accelerates autoregressive inference for Transformer-based models, such as large language models and other encoder-decoder architectures.
+This document details the implementation and usage of the key-value (kv) cache for the `MultiHeadAttention` and `SDPA` layers in ncnn. This feature significantly accelerates autoregressive inference for Transformer-based models, such as large language models and other encoder-decoder architectures.
 
 ## 1. what is kv cache?
 
@@ -20,9 +20,9 @@ Without optimization, the model must recompute the k and v matrices for all prec
 - **reduced computation:** It eliminates redundant calculations, saving significant computational resources and energy.
 - **enables real-time applications:** The performance gain makes it feasible to deploy large Transformer models for interactive and real-time tasks.
 
-## 2. ncnn mha kv cache implementation
+## 2. ncnn kv cache implementation
 
-ncnn introduces kv cache support directly into its `MultiHeadAttention` layer. The implementation is designed to be efficient and flexible, handling both the dynamic cache of self-attention and the static k/v of cross-attention found in encoder-decoder architectures.
+ncnn introduces kv cache support directly into its `MultiHeadAttention` and `SDPA` layers. The implementation is designed to be efficient and flexible, handling both the dynamic cache of self-attention and the static k/v of cross-attention found in encoder-decoder architectures.
 
 ### self-attention vs. cross-attention cache logic
 
@@ -31,34 +31,49 @@ The caching strategy is fundamentally different for self-attention and cross-att
 #### self-attention (dynamic cache)
 - **purpose:** Allows the decoder to attend to previously generated tokens in its own sequence (e.g., the text being generated).
 - **cache Logic:** The cache is **dynamic** and grows with each generated token. In step `t`, the k and v for token `t` are computed and appended to the cache from step `t-1`.
-- **ncnn implementation:** The `MultiHeadAttention` layer for self-attention is modified to accept two additional inputs (`cache_k_in`, `cache_v_in`) and produce two corresponding outputs (`cache_k_out`, `cache_v_out`). The `7=1` parameter enables this dynamic caching behavior inside the layer.
+- **ncnn implementation:** The `MultiHeadAttention` and `SDPA` layers for self-attention are modified to accept two additional inputs (`cache_k_in`, `cache_v_in`) and produce two corresponding outputs (`cache_k_out`, `cache_v_out`). The `7=1` parameter enables this dynamic caching behavior inside the layer.
 
 #### cross-attention (static k/v)
 - **purpose:** Allows the decoder to attend to the output of the encoder (e.g., attending to audio features in speech recognition, or an input sentence in translation).
 - **cache Logic:** The k and v matrices are derived from the encoder's output, which is computed only **once** per input sequence. Therefore, the k and v for cross-attention are **static** and do not change during the decoding process. They are "cached" in the sense that they are pre-computed and reused in every decoding step.
-- **ncnn implementation:** The `MultiHeadAttention` layer for cross-attention is also configured with `7=1` and cache I/O blobs. However, the implementation correctly identifies cross-attention (where the query blob is different from the key/value blobs) and reuses the `cache_k_in` and `cache_v_in` directly, without performing concatenation. This allows the static encoder k/v to be passed efficiently through the network.
+- **ncnn implementation:** The `MultiHeadAttention` and `SDPA` layers for cross-attention are also configured with `7=1` and cache I/O blobs. However, the implementation correctly identifies cross-attention (where the query blob is different from the key/value blobs) and reuses the `cache_k_in` and `cache_v_in` directly, without performing concatenation. This allows the static encoder k/v to be passed efficiently through the network.
 
-## 3. ncnn mha kv cache memory layout
+## 3. ncnn kv cache memory layout
 
-The memory layout of the kv cache is a critical design choice for performance. ncnn uses a **transposed layout** for the cache blobs. The primary reason for this is to **ensure that data for each attention head is contiguous in memory, which significantly boosts gemm performance.**
+The memory layout of the kv cache is a critical design choice for performance. ncnn uses different layouts for `MultiHeadAttention` and `SDPA` to optimize for their respective calculation patterns.
+
+### `MultiHeadAttention` cache layout (Transposed)
+
+The `MultiHeadAttention` layer uses a **transposed layout** for its cache blobs. The primary reason for this is to **ensure that data for each attention head is contiguous in memory, which significantly boosts gemm performance.**
 
 * **input blobs (q, k, v):** These typically have a shape where height represents the sequence length.
     * `ncnn::Mat` dimensions: `(w = embed_dim, h = seq_len)`
 
-* **cache blobs (e.g., `k_affine`, `v_affine`):** These are stored in a **transposed** format.
+* **cache blobs (`k_cache`, `v_cache`):** These are stored in a **transposed** format.
    * `ncnn::Mat` dimensions: `(w = seq_len, h = embed_dim)`
 
 **the rationale:**
 
-1. **slicing by Head:** During the attention calculation, the code slices the `k_affine` and `v_affine` matrices along their height to isolate the data for each head (e.g., using `row_range(head_index * embed_dim_per_head, embed_dim_per_head)`).
+1. **slicing by Head:** During the attention calculation, the code slices the `k_cache` and `v_cache` matrices along their height to isolate the data for each head (e.g., using `row_range(head_index * embed_dim_per_head, embed_dim_per_head)`).
 2. **memory contiguity:** Because `ncnn::Mat` uses a row-major memory layout, this slicing operation on the transposed cache blob results in a sub-matrix where all the data for a single head is perfectly contiguous.
 3. **gemm efficiency:** Subsequent matrix multiplication operations (`q * k^T` and `Attention * v`) can then operate on these contiguous memory blocks. This maximizes CPU cache locality and the effectiveness of simd instructions, leading to a substantial increase in computational speed.
 
 If a non-transposed layout were used, the data for each head would be strided in memory, causing frequent cache misses and dramatically slowing down the performance-critical gemm calculations. Therefore, this transposed layout is a deliberate and crucial optimization for computation.
 
+### `SDPA` cache layout (Standard)
+
+The `SDPA` layer uses the **standard ncnn Mat layout**, where the sequence length is represented by the height.
+
+* **input blobs (q, k, v):** `(w = embed_dim, h = seq_len, c = num_heads)`
+* **cache blobs (`k_cache`, `v_cache`):** `(w = embed_dim, h = seq_len, c = num_heads)`
+
+**the rationale:**
+
+The `SDPA` layer's internal implementation directly concatenates the cache blobs (`past_k`, `past_v`) with the current ones (`cur_k`, `cur_v`) along the height dimension (`seq_len`). This simpler approach avoids the need for a transposed layout while still being highly efficient, as the concatenation logic is handled inside the optimized C++ implementation.
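To make the two layouts above concrete, here is a small, self-contained C++ sketch. It is illustrative only and not code from the ncnn sources: the sizes and variable names are assumptions chosen for this example. It shows how a transposed `(w = seq_len, h = embed_dim)` cache can be sliced per head with `row_range` into a contiguous block, alongside the standard `(w = embed_dim, h = seq_len, c = num_heads)` layout described for `SDPA`.

```cpp
#include <cstdio>

#include "mat.h" // ncnn::Mat

int main()
{
    // illustrative sizes, not taken from any real model
    const int embed_dim = 8;
    const int num_heads = 2;
    const int embed_dim_per_head = embed_dim / num_heads;
    const int past_len = 5; // number of tokens already cached

    // MultiHeadAttention-style cache: transposed, (w = seq_len, h = embed_dim).
    // With row-major storage, each row holds one embedding channel across all
    // cached tokens, so the rows belonging to one head form one contiguous block.
    ncnn::Mat k_cache(past_len, embed_dim);
    k_cache.fill(0.f);

    for (int h = 0; h < num_heads; h++)
    {
        // slice out this head's rows; the view references contiguous memory,
        // which is what keeps the per-head gemm cache-friendly
        ncnn::Mat k_head = k_cache.row_range(h * embed_dim_per_head, embed_dim_per_head);
        printf("head %d slice: w=%d h=%d\n", h, k_head.w, k_head.h);
    }

    // SDPA-style cache as described above: standard layout,
    // (w = embed_dim, h = seq_len, c = num_heads); appending a token
    // concatenates one extra row per head along h.
    ncnn::Mat sdpa_k_cache(embed_dim, past_len, num_heads);
    sdpa_k_cache.fill(0.f);

    return 0;
}
```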
+
 ## 4. converting models to support kv cache
 
-To enable kv cache, you must modify the model's `.param` file to add the necessary cache inputs and outputs to all `MultiHeadAttention` layers in the decoder.
+To enable kv cache, you must modify the model's `.param` file to add the necessary cache inputs and outputs to all `MultiHeadAttention` and `SDPA` layers in the decoder.
 
 ### step 1: export a sequence-length-1 model
 
@@ -68,9 +83,9 @@ First, export your model from its original framework (e.g., PyTorch) using a seq
 
 After exporting, a script is needed to edit the generated `.ncnn.param` file to make it cache-aware.
 
-#### A. Adding kv cache to All MultiHeadAttention Layers
+#### A. Adding kv cache to All MultiHeadAttention and SDPA Layers
 
-You must add cache inputs/outputs to **every** `MultiHeadAttention` layer in the decoder.
+You must add cache inputs/outputs to **every** `MultiHeadAttention` / `SDPA` layer in the decoder.
 
 - **change `input_count` and `output_count`:** Increase both by 2.
 - **add blob names:** Append new, unique blob names for `cache_k_in`, `cache_v_in`, `cache_k_out`, and `cache_v_out`.
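As an illustration of the modification described above, here is a hypothetical before/after for a single self-attention layer line in a `.param` file. The layer name, blob names, and the `0=512 1=8` values are placeholders invented for this example; only the structural change (both counts increased by 2, cache blobs appended after the existing input and output blobs, and the `7=1` parameter that section 2 describes as enabling dynamic caching) reflects what the document specifies.

Before:

```
MultiHeadAttention       attn_0                   3 1 q0 k0 v0 attn_out0 0=512 1=8
```

After:

```
MultiHeadAttention       attn_0                   5 3 q0 k0 v0 cache_k_in_0 cache_v_in_0 attn_out0 cache_k_out_0 cache_v_out_0 0=512 1=8 7=1
```

The conversion script below also inserts a single `Input` layer (`kv_cache_in`) that produces all of the `cache_k_in_*` / `cache_v_in_*` blobs, and it updates the layer and blob counts in the param header accordingly.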
@@ -81,7 +96,7 @@ Here is a robust Python function that automates this process:
 def add_kv_cache_to_ncnn_param(filename):
     """
     Modifies an ncnn.param file to add a kv cache mechanism to all
-    MultiHeadAttention layers and overwrites the original file.
+    MultiHeadAttention and SDPA layers and overwrites the original file.
     This handles both self-attention and cross-attention layers.
     """
     import os
@@ -98,15 +113,15 @@ def add_kv_cache_to_ncnn_param(filename):
     original_layer_count = int(header_parts[0])
     original_blob_count = int(header_parts[1])
 
-    mha_indices = [i for i, line in enumerate(lines) if line.strip().startswith("MultiHeadAttention")]
-    mha_count = len(mha_indices)
+    attention_indices = [i for i, line in enumerate(lines) if line.strip().startswith("MultiHeadAttention") or line.strip().startswith("SDPA")]
+    attention_count = len(attention_indices)
 
-    if mha_count == 0:
-        print("No 'MultiHeadAttention' layers found. The file will not be modified.")
+    if attention_count == 0:
+        print("No 'MultiHeadAttention' or 'SDPA' layers found. The file will not be modified.")
         return
 
-    # --- modify MultiHeadAttention layers ---
-    for i, line_index in enumerate(mha_indices):
+    # --- modify MultiHeadAttention and SDPA layers ---
+    for i, line_index in enumerate(attention_indices):
         parts = lines[line_index].strip().split()
         layer_type, layer_name, input_count_str, output_count_str = parts[:4]
         input_count, output_count = int(input_count_str), int(output_count_str)
@@ -132,15 +147,15 @@ def add_kv_cache_to_ncnn_param(filename):
     new_layer_count = original_layer_count + 1
     # each mha needs 2 new *input* blobs and produces 2 new *output* blobs.
     # the total number of unique blobs increases by 4 for each mha.
-    new_blob_count = original_blob_count + (mha_count * 4)
+    new_blob_count = original_blob_count + (attention_count * 4)
     lines[header_line_index] = f"{new_layer_count} {new_blob_count}\n"
 
     # find where to insert the new input layer (after existing ones)
     insert_pos = header_line_index + 1
     while insert_pos < len(lines) and lines[insert_pos].strip().startswith("Input"):
         insert_pos += 1
 
-    cache_blob_names = [name for i in range(mha_count) for name in (f"cache_k_in_{i}", f"cache_v_in_{i}")]
+    cache_blob_names = [name for i in range(attention_count) for name in (f"cache_k_in_{i}", f"cache_v_in_{i}")]
     input_layer_line = (
         f"{'Input':<24} {'kv_cache_in':<24} 0 {len(cache_blob_names)} "
         f"{' '.join(cache_blob_names)}\n"
@@ -150,7 +165,7 @@ def add_kv_cache_to_ncnn_param(filename):
     with open(filename, 'w', encoding='utf-8') as f:
         f.writelines(lines)
 
-    print(f"Successfully added kv cache to {mha_count} MultiHeadAttention layers.")
+    print(f"Successfully added kv cache to {attention_count} MultiHeadAttention / SDPA layers.")
 
 # usage:
 # add_kv_cache_to_ncnn_param("your_model_decoder.ncnn.param")
@@ -206,7 +221,7 @@ void find_mha_kvcache_blobs(const ncnn::Net& net, kvcache_info& info)
     for (const ncnn::Layer* layer : net.layers())
     {
         // cache-enabled mha layer has 3 outputs (out, cache_k_out, cache_v_out) instead of 1
-        if (layer->typeindex == ncnn::LayerType::MultiHeadAttention && layer->tops.size() == 3)
+        if ((layer->typeindex == ncnn::LayerType::MultiHeadAttention || layer->typeindex == ncnn::LayerType::SDPA) && layer->tops.size() == 3)
         {
             // the script adds cache_k and cache_v as the last two inputs/outputs
             int input_count = layer->bottoms.size();
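Building on the detection logic above, the following is a minimal, hypothetical sketch of a per-step decode loop. The decoder blob names `in0` and `out0`, and the assumption that an empty `ncnn::Mat` is accepted as the zero-length initial cache, are not taken from this document; the cache input blob names follow the `cache_k_in_{i}` pattern used by the conversion script, the output names are assumed to follow the same pattern, and `cache_k` / `cache_v` are assumed to have been sized to the number of cache-enabled attention layers beforehand.

```cpp
#include <cstdio>
#include <vector>

#include "net.h" // ncnn::Net, ncnn::Extractor, ncnn::Mat

// run one decode step: feed the previous step's caches in, run the decoder
// for a single token, then pull the grown caches back out for the next step
static void decode_step(const ncnn::Net& decoder,
                        const ncnn::Mat& token_embedding,
                        std::vector<ncnn::Mat>& cache_k,
                        std::vector<ncnn::Mat>& cache_v,
                        ncnn::Mat& logits)
{
    ncnn::Extractor ex = decoder.create_extractor();

    ex.input("in0", token_embedding); // hypothetical decoder input blob

    // previous caches; on the very first step these are empty Mats
    for (size_t i = 0; i < cache_k.size(); i++)
    {
        char name[64];
        sprintf(name, "cache_k_in_%zu", i);
        ex.input(name, cache_k[i]);
        sprintf(name, "cache_v_in_%zu", i);
        ex.input(name, cache_v[i]);
    }

    ex.extract("out0", logits); // hypothetical decoder output blob

    // collect the updated caches so the next call can reuse them
    for (size_t i = 0; i < cache_k.size(); i++)
    {
        char name[64];
        sprintf(name, "cache_k_out_%zu", i);
        ex.extract(name, cache_k[i]);
        sprintf(name, "cache_v_out_%zu", i);
        ex.extract(name, cache_v[i]);
    }
}
```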
