// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "sdpa_x86.h"

#include <math.h>

#include "layer_type.h"

namespace ncnn {

SDPA_x86::SDPA_x86()
{
    qk_gemm = 0;
    qkv_gemm = 0;
    qk_softmax = 0;
}

int SDPA_x86::create_pipeline(const Option& _opt)
{
    Option opt = _opt;
    if (int8_scale_term)
    {
        opt.use_packing_layout = false; // TODO enable packing
    }

    {
        qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax);
        ncnn::ParamDict pd;
        pd.set(0, -1); // axis
        pd.set(1, 1); // fixbug0
        qk_softmax->load_param(pd);
        qk_softmax->load_model(ModelBinFromMatArray(0));
        qk_softmax->create_pipeline(opt);
    }

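    // when scale is a known nonzero constant, the Q*K^T gemm can be created
    // once here; scale == 0.f means "derive 1/sqrt(embed_dim) at runtime",
    // and that case builds a temporary gemm per forward() call instead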
    // Q * K^T
    if (scale != 0.f)
    {
        qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
        ncnn::ParamDict pd;

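        // note: ncnn's Gemm appears to apply alpha to the whole accumulation,
        // out = alpha * (A * B + beta * C), so beta = 1.f / scale below cancels
        // alpha on the C term and the attention mask is added unscaled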
        pd.set(0, scale);       // alpha
        pd.set(1, 1.f / scale); // beta
        pd.set(2, 0);           // transA (Q: Seq x Embed)
        pd.set(3, 1);           // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T
        pd.set(4, 0);           // constantA
        pd.set(5, 0);           // constantB
        pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it)
        pd.set(7, 0);           // M
        pd.set(8, 0);           // N
        pd.set(9, 0);           // K
        pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN)
        pd.set(11, 0);          // output_N1M
        pd.set(12, 1);          // output_elempack
#if NCNN_INT8
        pd.set(18, int8_scale_term);
#endif
        qk_gemm->load_param(pd);
        qk_gemm->load_model(ModelBinFromMatArray(0));
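        // forward() parallelizes over heads, so each gemm pipeline runs single-threaded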
        Option opt1 = opt;
        opt1.num_threads = 1;
        qk_gemm->create_pipeline(opt1);
    }

    // Attn * V
    {
        qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
        ncnn::ParamDict pd;
        pd.set(0, 1.f); // alpha
        pd.set(1, 1.f); // beta
        pd.set(2, 0);   // transA (Attn: Seq x Seq)
        pd.set(3, 0);   // transB (V: Seq x Embed) => Attn * V
        pd.set(4, 0);   // constantA
        pd.set(5, 0);   // constantB
        pd.set(6, 1);   // constantC (None)
        pd.set(7, 0);   // M
        pd.set(8, 0);   // N
        pd.set(9, 0);   // K
        pd.set(10, -1); // constant_broadcast_type_C
        pd.set(11, 0);  // output_N1M
        pd.set(12, 1);  // output_elempack
        pd.set(14, 0);  // output_transpose
#if NCNN_INT8
        pd.set(18, int8_scale_term);
#endif
        qkv_gemm->load_param(pd);
        qkv_gemm->load_model(ModelBinFromMatArray(0));
        Option opt1 = opt;
        opt1.num_threads = 1;
        qkv_gemm->create_pipeline(opt1);
    }

    return 0;
}

int SDPA_x86::destroy_pipeline(const Option& _opt)
{
    Option opt = _opt;
    if (int8_scale_term)
    {
        opt.use_packing_layout = false; // TODO enable packing
    }

    if (qk_softmax)
    {
        qk_softmax->destroy_pipeline(opt);
        delete qk_softmax;
        qk_softmax = 0;
    }

    if (qk_gemm)
    {
        qk_gemm->destroy_pipeline(opt);
        delete qk_gemm;
        qk_gemm = 0;
    }

    if (qkv_gemm)
    {
        qkv_gemm->destroy_pipeline(opt);
        delete qkv_gemm;
        qkv_gemm = 0;
    }

    return 0;
}

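// scaled dot-product attention:
//   out = softmax(Q * K^T * scale + attn_mask) * V
// evaluated per attention head; Q is (w=embed_dim, h=src_seqlen, c=num_heads),
// K/V carry num_group heads (grouped-query attention), and the optional
// kv_cache blobs hold the K/V of previously processed tokens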
int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& _opt) const
{
    Option opt = _opt;
    if (int8_scale_term)
    {
        opt.use_packing_layout = false; // TODO enable packing
    }

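    // bottom blob order: query, key, value, then the optional attn_mask,
    // then the optional past_key / past_value when kv_cache is enabled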
    const Mat& query = bottom_blobs[0];
    const Mat& cur_key = bottom_blobs[1];
    const Mat& cur_value = bottom_blobs[2];
    const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
    const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
    const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat();

    const int embed_dim = query.w;
    const int src_seqlen = query.h;
    const int num_heads = query.c;
    const int cur_seqlen = cur_key.h;
    const int num_group = cur_key.c;
    const int out_embed_dim = cur_value.w;
    const int past_seqlen = kv_cache ? past_key.h : 0;
    const int dst_seqlen = past_seqlen + cur_seqlen;

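    // concatenate the cached K (and V below) ahead of the current tokens so
    // attention spans all dst_seqlen positions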
    Mat key;
    if (past_seqlen > 0)
    {
        key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
        if (key.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_group; q++)
        {
            const Mat past_key_head = past_key.channel(q);
            const Mat cur_key_head = cur_key.channel(q);
            Mat key_head = key.channel(q);

            memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float));
            memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float));
        }
    }
    else
    {
        key = cur_key;
    }

    Mat value;
    if (past_seqlen > 0)
    {
        value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
        if (value.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_group; q++)
        {
            const Mat past_value_head = past_value.channel(q);
            const Mat cur_value_head = cur_value.channel(q);
            Mat value_head = value.channel(q);

            memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float));
            memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float));
        }
    }
    else
    {
        value = cur_value;
    }

    Mat& top_blob = top_blobs[0];
    top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

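    // grouped-query attention: each K/V head is shared by num_heads / num_group query heads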
    const int num_heads_per_group = num_heads / num_group;

    Mat qk_cross(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
    if (qk_cross.empty())
        return -100;

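    // per-head gemm return codes are collected here because an early return is
    // not allowed from inside an omp parallel for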
    std::vector<int> retqks(num_heads);

    // scale == 0.f: derive 1/sqrt(embed_dim) at runtime and build a temporary
    // qk gemm with the same alpha/beta correction as in create_pipeline
    Layer* _qk_gemm = qk_gemm;
    if (scale == 0.f)
    {
        float _scale = 1.f / sqrtf((float)embed_dim);

        _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
        ncnn::ParamDict pd;

        pd.set(0, _scale);       // alpha
        pd.set(1, 1.f / _scale); // beta
        pd.set(2, 0);            // transA (Q: Seq x Embed)
        pd.set(3, 1);            // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T
        pd.set(4, 0);            // constantA
        pd.set(5, 0);            // constantB
        pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it)
        pd.set(7, 0);            // M
        pd.set(8, 0);            // N
        pd.set(9, 0);            // K
        pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN)
        pd.set(11, 0);           // output_N1M
        pd.set(12, 1);           // output_elempack
#if NCNN_INT8
        pd.set(18, int8_scale_term);
#endif
        _qk_gemm->load_param(pd);
        _qk_gemm->load_model(ModelBinFromMatArray(0));

        Option opt1 = opt;
        opt1.num_threads = 1;
        _qk_gemm->create_pipeline(opt1);
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int i = 0; i < num_heads; i++)
    {
        // 1. Q * K^T
        std::vector<Mat> qk_bottom_blobs;
        qk_bottom_blobs.push_back(query.channel(i));                     // Q: [Seq, Embed]
        qk_bottom_blobs.push_back(key.channel(i / num_heads_per_group)); // K: [DstSeq, Embed]

        if (attn_mask)
        {
            // Ensure mask is 2D for Gemm auto-broadcast detection
            Mat maskm = attn_mask_blob;
            if (maskm.dims == 3)
            {
                // If c > 1, pick i-th head mask. If c == 1, pick 0-th (broadcast)
                maskm = maskm.channel(maskm.c > 1 ? i : 0);
            }
            qk_bottom_blobs.push_back(maskm);
        }

        std::vector<Mat> qk_top_blobs(1);
        qk_top_blobs[0] = qk_cross.channel(i);

        Option opt1 = opt;
        opt1.num_threads = 1;
        opt1.blob_allocator = qk_cross.allocator;
        retqks[i] = _qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1);
    }

    if (scale == 0.f)
    {
        Option opt1 = opt;
        opt1.num_threads = 1;
        _qk_gemm->destroy_pipeline(opt1);

        delete _qk_gemm;
        _qk_gemm = 0;
    }

    for (int i = 0; i < num_heads; i++)
    {
        if (retqks[i] != 0)
            return retqks[i];
    }

    // 2. Softmax
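    // normalizes each row of qk_cross in place along the dst_seqlen dimension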
    int retqk = qk_softmax->forward_inplace(qk_cross, opt);
    if (retqk != 0)
        return retqk;

    // 3. Attn * V
    std::vector<int> retqkvs(num_heads);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int i = 0; i < num_heads; i++)
    {
        std::vector<Mat> qkv_bottom_blobs(2);
        qkv_bottom_blobs[0] = qk_cross.channel(i);                    // Attn: [DstSeq, Seq]
        qkv_bottom_blobs[1] = value.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed]

        std::vector<Mat> qkv_top_blobs(1);
        qkv_top_blobs[0] = top_blob.channel(i); // Output

        Option opt1 = opt;
        opt1.num_threads = 1;
        retqkvs[i] = qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1);
    }

    for (int i = 0; i < num_heads; i++)
    {
        if (retqkvs[i] != 0)
            return retqkvs[i];
    }

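    // hand the concatenated K/V back as the cache blobs for the next step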
    if (kv_cache)
    {
        top_blobs[1] = key;
        top_blobs[2] = value;
    }

    return 0;
}

} // namespace ncnn