# high-performance transformer inference with mha kv cache in ncnn
This document details the implementation and usage of the key-value (kv) cache for the `MultiHeadAttention` and `SDPA` layers in ncnn. This feature significantly accelerates autoregressive inference for Transformer-based models, such as large language models and other encoder-decoder architectures.
## 1. what is kv cache?
Without optimization, the model must recompute the k and v matrices for all preceding tokens at every decoding step. Caching these values and reusing them brings two key benefits:
- **reduced computation:** It eliminates redundant calculations, saving significant computational resources and energy.
- **enables real-time applications:** The performance gain makes it feasible to deploy large Transformer models for interactive and real-time tasks.
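
To make the saving concrete, here is a minimal numpy sketch (not ncnn code; the projection weights, shapes, and function names are illustrative assumptions) contrasting recomputing k/v for the whole sequence at every step with appending only the newest token's k/v to a cache:

```python
import numpy as np

embed_dim = 64
rng = np.random.default_rng(0)
w_k = rng.standard_normal((embed_dim, embed_dim), dtype=np.float32)  # illustrative k projection
w_v = rng.standard_normal((embed_dim, embed_dim), dtype=np.float32)  # illustrative v projection

def decode_without_cache(tokens):
    # every step re-projects k/v for ALL tokens seen so far: O(t) work per step
    for t in range(1, len(tokens) + 1):
        k, v = tokens[:t] @ w_k, tokens[:t] @ w_v
    return k, v

def decode_with_cache(tokens):
    # every step projects only the newest token and appends it to the cache: O(1) work per step
    k_cache = np.empty((0, embed_dim), dtype=np.float32)
    v_cache = np.empty((0, embed_dim), dtype=np.float32)
    for t in range(len(tokens)):
        cur = tokens[t:t + 1]
        k_cache = np.concatenate([k_cache, cur @ w_k], axis=0)
        v_cache = np.concatenate([v_cache, cur @ w_v], axis=0)
    return k_cache, v_cache

tokens = rng.standard_normal((8, embed_dim), dtype=np.float32)  # 8 already-embedded tokens
k1, v1 = decode_without_cache(tokens)
k2, v2 = decode_with_cache(tokens)
assert np.allclose(k1, k2, atol=1e-5) and np.allclose(v1, v2, atol=1e-5)  # same result, far less work
```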
## 2. ncnn kv cache implementation
ncnn introduces kv cache support directly into its `MultiHeadAttention` and `SDPA` layers. The implementation is designed to be efficient and flexible, handling both the dynamic cache of self-attention and the static k/v of cross-attention found in encoder-decoder architectures.
### self-attention vs. cross-attention cache logic
The caching strategy is fundamentally different for self-attention and cross-attention.
#### self-attention (dynamic cache)
- **purpose:** Allows the decoder to attend to previously generated tokens in its own sequence (e.g., the text being generated).
- **cache logic:** The cache is **dynamic** and grows with each generated token. In step `t`, the k and v for token `t` are computed and appended to the cache from step `t-1`.
- **ncnn implementation:** The `MultiHeadAttention` and `SDPA` layers for self-attention are modified to accept two additional inputs (`cache_k_in`, `cache_v_in`) and produce two corresponding outputs (`cache_k_out`, `cache_v_out`). The `7=1` parameter enables this dynamic caching behavior inside the layer.
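
As a usage illustration, the sketch below shows how these cache blobs are typically wired in a decode loop with ncnn's Python bindings: each step feeds the previous step's `cache_k_out`/`cache_v_out` back in as the next step's `cache_k_in`/`cache_v_in`. The file names, blob names (`in0`, `out0`, `cache_k_in0`, ...), the helper functions, and the use of an empty `ncnn.Mat()` as the initial cache are assumptions for illustration only; adapt them to your exported graph.

```python
import ncnn

def embed_token(token_id):
    # hypothetical helper: build the input Mat for this token
    ...

def pick_next_token(logits):
    # hypothetical helper: greedy or sampled selection from the output logits
    ...

net = ncnn.Net()
net.load_param("decoder.ncnn.param")  # assumed file names
net.load_model("decoder.ncnn.bin")

k_cache = ncnn.Mat()  # assumed: empty cache for the first decoding step
v_cache = ncnn.Mat()
token = 0             # assumed start token id

for step in range(32):
    ex = net.create_extractor()
    ex.input("in0", embed_token(token))
    ex.input("cache_k_in0", k_cache)         # feed last step's cache back in
    ex.input("cache_v_in0", v_cache)

    _, logits = ex.extract("out0")
    _, k_cache = ex.extract("cache_k_out0")  # cache grown by one token inside the layer
    _, v_cache = ex.extract("cache_v_out0")

    token = pick_next_token(logits)
```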
#### cross-attention (static k/v)
- **purpose:** Allows the decoder to attend to the output of the encoder (e.g., attending to audio features in speech recognition, or an input sentence in translation).
- **cache logic:** The k and v matrices are derived from the encoder's output, which is computed only **once** per input sequence. Therefore, the k and v for cross-attention are **static** and do not change during the decoding process. They are "cached" in the sense that they are pre-computed and reused in every decoding step.
- **ncnn implementation:** The `MultiHeadAttention` and `SDPA` layers for cross-attention are also configured with `7=1` and cache I/O blobs. However, the implementation correctly identifies cross-attention (where the query blob is different from the key/value blobs) and reuses the `cache_k_in` and `cache_v_in` directly, without performing concatenation. This allows the static encoder k/v to be passed efficiently through the network.
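
Conceptually, the cache handling inside the layer reduces to a single branch on whether the layer is doing self- or cross-attention. The snippet below is a schematic numpy sketch of that logic, not the actual C++ implementation:

```python
import numpy as np

def update_kv_cache(k_cur, v_cur, k_past, v_past, is_cross_attention):
    """Schematic version of the per-step kv cache update."""
    if is_cross_attention:
        # cross-attention: encoder k/v were computed once; reuse the cached copies unchanged
        return k_past, v_past
    # self-attention: append this step's k/v to the cache along the sequence axis
    k_new = np.concatenate([k_past, k_cur], axis=0) if k_past.size else k_cur
    v_new = np.concatenate([v_past, v_cur], axis=0) if v_past.size else v_cur
    return k_new, v_new
```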
## 3. ncnn kv cache memory layout
The memory layout of the kv cache is a critical design choice for performance. ncnn uses different layouts for `MultiHeadAttention` and `SDPA` to optimize for their respective calculation patterns.

### `MultiHeadAttention` cache layout (Transposed)

The `MultiHeadAttention` layer uses a **transposed layout** for its cache blobs. The primary reason for this is to **ensure that data for each attention head is contiguous in memory, which significantly boosts gemm performance.**
* **input blobs (q, k, v):** These typically have a shape where height represents the sequence length.
  * `ncnn::Mat` dimensions: `(w = embed_dim, h = seq_len)`
* **cache blobs (`k_cache`, `v_cache`):** These are stored in a **transposed** format.
  * `ncnn::Mat` dimensions: `(w = seq_len, h = embed_dim)`

**the rationale:**
1. **slicing by head:** During the attention calculation, the code slices the `k_cache` and `v_cache` matrices along their height to isolate the data for each head (e.g., using `row_range(head_index * embed_dim_per_head, embed_dim_per_head)`).
2. **memory contiguity:** Because `ncnn::Mat` uses a row-major memory layout, this slicing operation on the transposed cache blob results in a sub-matrix where all the data for a single head is perfectly contiguous.
3. **gemm efficiency:** Subsequent matrix multiplication operations (`q * k^T` and `Attention * v`) can then operate on these contiguous memory blocks. This maximizes CPU cache locality and the effectiveness of simd instructions, leading to a substantial increase in computational speed.

If a non-transposed layout were used, the data for each head would be strided in memory, causing frequent cache misses and dramatically slowing down the performance-critical gemm calculations. Therefore, this transposed layout is a deliberate and crucial optimization for computation.
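
The contiguity argument can be checked with a small numpy sketch (shapes and names here are illustrative, modelling the 2D `ncnn::Mat` as a row-major array): slicing rows of the transposed `(embed_dim, seq_len)` cache gives each head one contiguous block, while slicing columns of the non-transposed `(seq_len, embed_dim)` layout does not.

```python
import numpy as np

num_heads, embed_dim_per_head, seq_len = 4, 16, 128
embed_dim = num_heads * embed_dim_per_head
head = 1

# transposed cache layout (w = seq_len, h = embed_dim): the row_range-style slice is contiguous
k_cache_t = np.zeros((embed_dim, seq_len), dtype=np.float32)
head_rows = k_cache_t[head * embed_dim_per_head:(head + 1) * embed_dim_per_head, :]
print(head_rows.flags["C_CONTIGUOUS"])  # True: one contiguous block per head, gemm-friendly

# non-transposed layout (w = embed_dim, h = seq_len): the per-head slice is strided
k_cache = np.zeros((seq_len, embed_dim), dtype=np.float32)
head_cols = k_cache[:, head * embed_dim_per_head:(head + 1) * embed_dim_per_head]
print(head_cols.flags["C_CONTIGUOUS"])  # False: each row of the head is separated by a stride
```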
### `SDPA` cache layout (Standard)
The `SDPA` layer uses the **standard ncnn Mat layout**, where the sequence length is represented by the height.
* **input blobs (q, k, v):** `(w = embed_dim, h = seq_len, c = num_heads)`
* **cache blobs (`k_cache`, `v_cache`):** `(w = embed_dim, h = seq_len, c = num_heads)`

**the rationale:**

The `SDPA` layer's internal implementation directly concatenates the cache blobs (`past_k`, `past_v`) with the current ones (`cur_k`, `cur_v`) along the height dimension (`seq_len`). This simpler approach avoids the need for a transposed layout while still being highly efficient, as the concatenation logic is handled inside the optimized C++ implementation.
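
In numpy terms (an illustrative sketch, reading the `(w = embed_dim, h = seq_len, c = num_heads)` Mat above as a row-major array of shape `(num_heads, seq_len, embed_dim)`), the per-step update is just a concatenation along the sequence axis:

```python
import numpy as np

num_heads, embed_dim = 4, 64
past_k = np.zeros((num_heads, 7, embed_dim), dtype=np.float32)  # cache holding 7 past tokens
cur_k = np.zeros((num_heads, 1, embed_dim), dtype=np.float32)   # this step's single new token

k_cache = np.concatenate([past_k, cur_k], axis=1)  # grow along h = seq_len
print(k_cache.shape)  # (4, 8, 64)
```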
## 4. converting models to support kv cache
To enable kv cache, you must modify the model's `.param` file to add the necessary cache inputs and outputs to all `MultiHeadAttention` and `SDPA` layers in the decoder.
### step 1: export a sequence-length-1 model
First, export your model from its original framework (e.g., PyTorch) using a sequence length of 1.

After exporting, a script is needed to edit the generated `.ncnn.param` file to make it cache-aware.
#### A. Adding kv cache to All MultiHeadAttention and SDPA Layers
You must add cache inputs/outputs to **every** `MultiHeadAttention` / `SDPA` layer in the decoder.
- **change `input_count` and `output_count`:** Increase both by 2.
- **add blob names:** Append new, unique blob names for `cache_k_in`, `cache_v_in`, `cache_k_out`, and `cache_v_out`.
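
For illustration only, here is roughly what this edit looks like on a single layer line in the `.param` file. The layer name, blob names, and the single-input self-attention form are made up for this sketch, and the other layer-specific parameters are elided with `...`; your exported lines will differ. A line like

```
MultiHeadAttention /decoder/layer0/attn 1 1 in12 out12 ...
```

becomes, after increasing both counts by 2, appending the cache blob names, and setting the `7=1` kv cache parameter described earlier:

```
MultiHeadAttention /decoder/layer0/attn 3 3 in12 cache_k_in0 cache_v_in0 out12 cache_k_out0 cache_v_out0 ... 7=1
```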
Here is a robust Python function that automates this process:

```python
def add_kv_cache_to_ncnn_param(filename):
    """
    Modifies an ncnn.param file to add a kv cache mechanism to all
    MultiHeadAttention and SDPA layers and overwrites the original file.
    This handles both self-attention and cross-attention layers.
```