
Commit 9d294b2

spda x86 optimization (using gemm & softmax) (#6421)
1 parent 14d9715 commit 9d294b2

6 files changed: +363 -3 lines changed


src/layer/sdpa.cpp

Lines changed: 1 addition & 1 deletion
@@ -323,7 +323,7 @@ int SDPA::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
     if (qk_cross_int8.empty())
         return -100;
 
-    Mat query_or_qk_cross_int8_scales(src_seqlen, opt.num_threads, 4u, opt.workspace_allocator);
+    Mat query_or_qk_cross_int8_scales(src_seqlen, 1, opt.num_threads, 4u, opt.workspace_allocator);
     if (query_or_qk_cross_int8_scales.empty())
         return -100;
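
A note on this one-line fix: the per-thread int8 scale scratch goes from a 2D Mat (src_seqlen x num_threads) to a 3D Mat (src_seqlen x 1 x num_threads), presumably so each worker thread owns a cstep-aligned channel rather than a packed row. A minimal sketch contrasting the two shapes (illustrative only; scratch_layout_sketch is a made-up name, not part of this commit):

#include "mat.h"
#include "option.h"

// Illustrative only: contrast the 2D and 3D per-thread scratch layouts.
void scratch_layout_sketch(int src_seqlen, const ncnn::Option& opt)
{
    // before: one packed row per thread
    ncnn::Mat scales_2d(src_seqlen, opt.num_threads, 4u, opt.workspace_allocator);
    float* row0 = scales_2d.row(0); // thread 0 writes here

    // after: one channel per thread (each channel is a cstep-aligned slice)
    ncnn::Mat scales_3d(src_seqlen, 1, opt.num_threads, 4u, opt.workspace_allocator);
    ncnn::Mat ch0 = scales_3d.channel(0); // thread 0 writes here

    (void)row0;
    (void)ch0;
}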

src/layer/x86/gemm_x86.cpp

Lines changed: 2 additions & 0 deletions
@@ -7460,6 +7460,8 @@ int Gemm_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
     {
         Mat C2;
         C2.create_like(C, opt.workspace_allocator);
+        if (C2.empty())
+            return -100;
 
         const int size = C.total() * C.elempack;
         for (int i = 0; i < size; i++)
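
The two added lines are the usual ncnn out-of-memory guard: create_like() leaves the Mat empty when the workspace allocator fails, and that has to be reported as -100 before the buffer is written. A self-contained sketch of the idiom (copy_to_workspace is a hypothetical helper, not the Gemm_x86 code):

#include "mat.h"
#include "option.h"

// Hypothetical helper showing the create_like + empty() guard used above.
static int copy_to_workspace(const ncnn::Mat& c, ncnn::Mat& c2, const ncnn::Option& opt)
{
    c2.create_like(c, opt.workspace_allocator);
    if (c2.empty())
        return -100; // allocation failed -> ncnn's conventional error code

    const int size = (int)(c.total() * c.elempack);
    const float* ptr = c;
    float* outptr = c2;
    for (int i = 0; i < size; i++)
        outptr[i] = ptr[i];

    return 0;
}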

src/layer/x86/sdpa_x86.cpp

Lines changed: 326 additions & 0 deletions
@@ -0,0 +1,326 @@
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "sdpa_x86.h"

#include "layer_type.h"

namespace ncnn {

SDPA_x86::SDPA_x86()
{
    qk_gemm = 0;
    qkv_gemm = 0;
    qk_softmax = 0;
}

int SDPA_x86::create_pipeline(const Option& _opt)
{
    Option opt = _opt;
    if (int8_scale_term)
    {
        opt.use_packing_layout = false; // TODO enable packing
    }

    {
        qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax);
        ncnn::ParamDict pd;
        pd.set(0, -1); // axis
        pd.set(1, 1);
        qk_softmax->load_param(pd);
        qk_softmax->load_model(ModelBinFromMatArray(0));
        qk_softmax->create_pipeline(opt);
    }

    // Q * K^T
    if (scale != 0.f)
    {
        qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
        ncnn::ParamDict pd;

        pd.set(0, scale); // alpha
        pd.set(1, 1.f / scale); // beta
        pd.set(2, 0); // transA (Q: Seq x Embed)
        pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T
        pd.set(4, 0); // constantA
        pd.set(5, 0); // constantB
        pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it)
        pd.set(7, 0); // M
        pd.set(8, 0); // N
        pd.set(9, 0); // K
        pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN)
        pd.set(11, 0); // output_N1M
        pd.set(12, 1); // output_elempack
#if NCNN_INT8
        pd.set(18, int8_scale_term);
#endif
        qk_gemm->load_param(pd);
        qk_gemm->load_model(ModelBinFromMatArray(0));
        Option opt1 = opt;
        opt1.num_threads = 1;
        qk_gemm->create_pipeline(opt1);
    }

    // Attn * V
    {
        qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
        ncnn::ParamDict pd;
        pd.set(0, 1.f); // alpha
        pd.set(1, 1.f); // beta
        pd.set(2, 0); // transA (Attn: Seq x Seq)
        pd.set(3, 0); // transB (V: Seq x Embed) => Attn * V
        pd.set(4, 0); // constantA
        pd.set(5, 0); // constantB
        pd.set(6, 1); // constantC (None)
        pd.set(7, 0); // M
        pd.set(8, 0); // N
        pd.set(9, 0); // K
        pd.set(10, -1); // constant_broadcast_type_C
        pd.set(11, 0); // output_N1M
        pd.set(12, 1); // output_elempack
        pd.set(14, 0); // output_transpose
#if NCNN_INT8
        pd.set(18, int8_scale_term);
#endif
        qkv_gemm->load_param(pd);
        qkv_gemm->load_model(ModelBinFromMatArray(0));
        Option opt1 = opt;
        opt1.num_threads = 1;
        qkv_gemm->create_pipeline(opt1);
    }

    return 0;
}

int SDPA_x86::destroy_pipeline(const Option& _opt)
{
    Option opt = _opt;
    if (int8_scale_term)
    {
        opt.use_packing_layout = false; // TODO enable packing
    }

    if (qk_softmax)
    {
        qk_softmax->destroy_pipeline(opt);
        delete qk_softmax;
        qk_softmax = 0;
    }

    if (qk_gemm)
    {
        qk_gemm->destroy_pipeline(opt);
        delete qk_gemm;
        qk_gemm = 0;
    }

    if (qkv_gemm)
    {
        qkv_gemm->destroy_pipeline(opt);
        delete qkv_gemm;
        qkv_gemm = 0;
    }

    return 0;
}

int SDPA_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& _opt) const
{
    Option opt = _opt;
    if (int8_scale_term)
    {
        opt.use_packing_layout = false; // TODO enable packing
    }

    const Mat& query = bottom_blobs[0];
    const Mat& cur_key = bottom_blobs[1];
    const Mat& cur_value = bottom_blobs[2];
    const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat();
    const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat();
    const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat();

    const int embed_dim = query.w;
    const int src_seqlen = query.h;
    const int num_heads = query.c;
    const int cur_seqlen = cur_key.h;
    const int num_group = cur_key.c;
    const int out_embed_dim = cur_value.w;
    const int past_seqlen = kv_cache ? past_key.h : 0;
    const int dst_seqlen = past_seqlen + cur_seqlen;

    Mat key;
    if (past_seqlen > 0)
    {
        key.create(embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
        if (key.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_group; q++)
        {
            const Mat past_key_head = past_key.channel(q);
            const Mat cur_key_head = cur_key.channel(q);
            Mat key_head = key.channel(q);

            memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * sizeof(float));
            memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * sizeof(float));
        }
    }
    else
    {
        key = cur_key;
    }

    Mat value;
    if (past_seqlen > 0)
    {
        value.create(out_embed_dim, dst_seqlen, num_group, 4u, opt.blob_allocator);
        if (value.empty())
            return -100;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < num_group; q++)
        {
            const Mat past_value_head = past_value.channel(q);
            const Mat cur_value_head = cur_value.channel(q);
            Mat value_head = value.channel(q);

            memcpy(value_head.row(0), past_value_head, out_embed_dim * past_seqlen * sizeof(float));
            memcpy(value_head.row(past_seqlen), cur_value_head, out_embed_dim * cur_seqlen * sizeof(float));
        }
    }
    else
    {
        value = cur_value;
    }

    Mat& top_blob = top_blobs[0];
    top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    const int num_heads_per_group = num_heads / num_group;

    Mat qk_cross(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
    if (qk_cross.empty())
        return -100;

    std::vector<int> retqks(num_heads);

    // Dynamic Scale Calculation and Beta Correction
    Layer* _qk_gemm = qk_gemm;
    if (scale == 0.f)
    {
        float _scale = 1.f / sqrt(embed_dim);

        _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
        ncnn::ParamDict pd;

        pd.set(0, _scale); // alpha
        pd.set(1, 1.f / _scale); // beta
        pd.set(2, 0); // transA (Q: Seq x Embed)
        pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T
        pd.set(4, 0); // constantA
        pd.set(5, 0); // constantB
        pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it)
        pd.set(7, 0); // M
        pd.set(8, 0); // N
        pd.set(9, 0); // K
        pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C (MxN)
        pd.set(11, 0); // output_N1M
        pd.set(12, 1); // output_elempack
#if NCNN_INT8
        pd.set(18, int8_scale_term);
#endif
        _qk_gemm->load_param(pd);
        _qk_gemm->load_model(ModelBinFromMatArray(0));

        Option opt1 = opt;
        opt1.num_threads = 1;
        _qk_gemm->create_pipeline(opt1);
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int i = 0; i < num_heads; i++)
    {
        // 1. Q * K^T
        std::vector<Mat> qk_bottom_blobs;
        qk_bottom_blobs.push_back(query.channel(i)); // Q: [Seq, Embed]
        qk_bottom_blobs.push_back(key.channel(i / num_heads_per_group)); // K: [DstSeq, Embed]

        if (attn_mask)
        {
            // Ensure mask is 2D for Gemm auto-broadcast detection
            Mat maskm = attn_mask_blob;
            if (maskm.dims == 3)
            {
                // If c > 1, pick i-th head mask. If c == 1, pick 0-th (broadcast)
                maskm = maskm.channel(maskm.c > 1 ? i : 0);
            }
            qk_bottom_blobs.push_back(maskm);
        }

        std::vector<Mat> qk_top_blobs(1);
        qk_top_blobs[0] = qk_cross.channel(i);

        Option opt1 = opt;
        opt1.num_threads = 1;
        opt1.blob_allocator = qk_cross.allocator;
        retqks[i] = _qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1);
    }

    if (scale == 0.f)
    {
        Option opt1 = opt;
        opt1.num_threads = 1;
        _qk_gemm->destroy_pipeline(opt1);

        delete _qk_gemm;
        _qk_gemm = 0;
    }

    for (int i = 0; i < num_heads; i++)
    {
        if (retqks[i] != 0)
            return retqks[i];
    }

    // 2. Softmax
    int retqk = qk_softmax->forward_inplace(qk_cross, opt);
    if (retqk != 0)
        return retqk;

    // 3. Attn * V
    std::vector<int> retqkvs(num_heads);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int i = 0; i < num_heads; i++)
    {
        std::vector<Mat> qkv_bottom_blobs(2);
        qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq]
        qkv_bottom_blobs[1] = value.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed]

        std::vector<Mat> qkv_top_blobs(1);
        qkv_top_blobs[0] = top_blob.channel(i); // Output

        Option opt1 = opt;
        opt1.num_threads = 1;
        retqkvs[i] = qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1);
    }

    for (int i = 0; i < num_heads; i++)
    {
        if (retqkvs[i] != 0)
            return retqkvs[i];
    }

    if (kv_cache)
    {
        top_blobs[1] = key;
        top_blobs[2] = value;
    }

    return 0;
}

} // namespace ncnn
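
Reader's note on the data flow above: each head i reads query.channel(i) and shares key/value channel i / num_heads_per_group (the grouped-query case), and when scale is 0 the dynamic path falls back to 1/sqrt(embed_dim). The alpha = scale, beta = 1/scale pairing on the Q*K^T gemm appears intended to leave the optional mask term unscaled once alpha is applied (the "Beta Correction" named in the source comment). Below is a naive single-head reference of the same qk_gemm -> qk_softmax -> qkv_gemm composition; it is illustrative only (naive_sdpa_head is a made-up name), works on plain row-major float buffers, and ignores int8, kv-cache and masks.

#include <algorithm>
#include <cmath>
#include <vector>

// Naive single-head scaled dot-product attention with the layouts used above:
// q is src_seqlen x embed_dim, k is dst_seqlen x embed_dim,
// v is dst_seqlen x out_embed_dim, out is src_seqlen x out_embed_dim.
static void naive_sdpa_head(const float* q, const float* k, const float* v, float* out,
                            int src_seqlen, int dst_seqlen, int embed_dim, int out_embed_dim,
                            float scale)
{
    if (scale == 0.f)
        scale = 1.f / std::sqrt((float)embed_dim); // same fallback as the dynamic-scale path

    std::vector<float> attn(dst_seqlen);
    for (int i = 0; i < src_seqlen; i++)
    {
        // 1. Q * K^T, scaled (the qk_gemm step)
        for (int j = 0; j < dst_seqlen; j++)
        {
            float sum = 0.f;
            for (int d = 0; d < embed_dim; d++)
                sum += q[i * embed_dim + d] * k[j * embed_dim + d];
            attn[j] = sum * scale;
        }

        // 2. softmax over the dst_seqlen axis (the qk_softmax step)
        float maxv = attn[0];
        for (int j = 1; j < dst_seqlen; j++)
            maxv = std::max(maxv, attn[j]);
        float denom = 0.f;
        for (int j = 0; j < dst_seqlen; j++)
        {
            attn[j] = std::exp(attn[j] - maxv);
            denom += attn[j];
        }
        for (int j = 0; j < dst_seqlen; j++)
            attn[j] /= denom;

        // 3. Attn * V (the qkv_gemm step)
        for (int d = 0; d < out_embed_dim; d++)
        {
            float sum = 0.f;
            for (int j = 0; j < dst_seqlen; j++)
                sum += attn[j] * v[j * out_embed_dim + d];
            out[i * out_embed_dim + d] = sum;
        }
    }
}

The optimized path reaches the same result by handing these three stages to ncnn's Gemm and Softmax layers, one head per OpenMP worker with opt1.num_threads = 1 inside each call.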

src/layer/x86/sdpa_x86.h

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_SDPA_X86_H
#define LAYER_SDPA_X86_H

#include "sdpa.h"

namespace ncnn {

class SDPA_x86 : public SDPA
{
public:
    SDPA_x86();

    virtual int create_pipeline(const Option& opt);
    virtual int destroy_pipeline(const Option& opt);

    virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;

public:
    Layer* qk_gemm;
    Layer* qkv_gemm;

    Layer* qk_softmax;
};

} // namespace ncnn

#endif // LAYER_SDPA_X86_H
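
A hedged host-side usage sketch for the class declared here, assuming default SDPA params (no attn_mask, no kv_cache, runtime 1/sqrt(embed_dim) scale) and the blob layouts read in SDPA_x86::forward; run_sdpa is a made-up name and the default-param assumption is not part of this commit:

#include <vector>

#include "layer.h"
#include "modelbin.h"
#include "option.h"
#include "paramdict.h"

// Sketch only: run standalone SDPA on three input blobs.
// q: embed_dim x src_seqlen x num_heads, k: embed_dim x dst_seqlen x num_group,
// v: out_embed_dim x dst_seqlen x num_group (layouts taken from SDPA_x86::forward).
int run_sdpa(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, ncnn::Mat& out)
{
    ncnn::Layer* sdpa = ncnn::create_layer("SDPA"); // requires a build with NCNN_STRING
    if (!sdpa)
        return -1;

    ncnn::ParamDict pd; // defaults assumed: no mask, no kv-cache, scale resolved at runtime
    sdpa->load_param(pd);
    sdpa->load_model(ncnn::ModelBinFromMatArray(0));

    ncnn::Option opt;
    sdpa->create_pipeline(opt);

    std::vector<ncnn::Mat> bottoms(3);
    bottoms[0] = q;
    bottoms[1] = k;
    bottoms[2] = v;
    std::vector<ncnn::Mat> tops(1);

    int ret = sdpa->forward(bottoms, tops, opt);

    sdpa->destroy_pipeline(opt);
    delete sdpa;

    if (ret != 0)
        return ret;

    out = tops[0];
    return 0;
}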

tests/test_sdpa.cpp

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ static int test_sdpa_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Ma
         as.push_back(RandomMat(dst_seqlen, src_seqlen));
     }
 
-    float epsilon = 0.001;
+    float epsilon = 0.01;
 
     int ret = test_layer("SDPA", pd, weights, as, 1, epsilon);
     if (ret != 0)

tests/test_sdpa_kvcache.cpp

Lines changed: 3 additions & 1 deletion
@@ -83,7 +83,9 @@ static int test_sdpa_int8_kvcache(const ncnn::Mat& q, const ncnn::Mat& k, const
     as.push_back(RandomMat(embed_dim, past_seqlen, k.c));
     as.push_back(RandomMat(out_embed_dim, past_seqlen, v.c));
 
-    int ret = test_layer("SDPA", pd, weights, as, 3);
+    float epsilon = 0.01;
+
+    int ret = test_layer("SDPA", pd, weights, as, 3, epsilon);
     if (ret != 0)
     {
         fprintf(stderr, "test_sdpa_int8_kvcache failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d past_seqlen=%d\n", q.w, q.h, q.c, k.w, k.h, k.c, v.w, v.h, v.c, attn_mask, past_seqlen);
