Skip to content

Commit 1997510

Browse files
committed
oom test++
1 parent 2ce9b77 commit 1997510

File tree

1 file changed

+82
-7
lines changed

1 file changed

+82
-7
lines changed

tests/test_sdpa_oom.cpp

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,54 @@ static int test_sdpa_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat
3535
return ret;
3636
}
3737

38+
// Runs the out-of-memory test harness on the "SDPA" layer with the kv_cache
// option (param 7) switched on.
//
// q, k, v      query / current key / current value blobs
// attn_mask    forwarded as param 5; when non-zero an extra random mask blob
//              is appended to the inputs
// past_seqlen  height of the randomly generated past key / past value blobs
//
// Returns the result of test_layer_oom (0 on success); on failure the input
// shapes and options are printed to stderr.
static int test_sdpa_kvcache_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, int past_seqlen)
{
    ncnn::ParamDict pd;
    pd.set(5, attn_mask);
    pd.set(7, 1); // kv_cache

    // SDPA is tested without any weight blobs.
    const std::vector<ncnn::Mat> weights;

    std::vector<ncnn::Mat> inputs;
    inputs.push_back(q);
    inputs.push_back(k);
    inputs.push_back(v);

    if (attn_mask)
    {
        // Mask spans the full key length (past + current) by the query length.
        const int total_seqlen = past_seqlen + k.h;
        inputs.push_back(RandomMat(total_seqlen, q.h));
    }

    // Cached key / value from earlier steps, matching the current k / v in
    // width and channel count.
    inputs.push_back(RandomMat(q.w, past_seqlen, k.c));
    inputs.push_back(RandomMat(v.w, past_seqlen, v.c));

    // NOTE(review): the trailing 3 is presumably the number of output blobs
    // (attention output plus updated caches) - confirm against test_layer_oom.
    const int ret = test_layer_oom("SDPA", pd, weights, inputs, 3);
    if (ret == 0)
        return 0;

    fprintf(stderr, "test_sdpa_kvcache_oom failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d past_seqlen=%d\n", q.w, q.h, q.c, k.w, k.h, k.c, v.w, v.h, v.c, attn_mask, past_seqlen);
    return ret;
}
73+
3874
static int test_sdpa_0()
3975
{
4076
return 0
4177
|| test_sdpa_oom(RandomMat(32, 66, 8), RandomMat(32, 66, 8), RandomMat(20, 66, 8), 0)
4278
|| test_sdpa_oom(RandomMat(26, 64, 8), RandomMat(26, 61, 8), RandomMat(18, 61, 8), 1)
4379
|| test_sdpa_oom(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f)
44-
|| test_sdpa_oom(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f);
80+
|| test_sdpa_oom(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f)
81+
|| test_sdpa_kvcache_oom(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 3);
4582
}
4683

4784
#if NCNN_INT8
48-
static int test_sdpa_oom_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, float scale = 0.f)
85+
static int test_sdpa_int8_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, float scale = 0.f)
4986
{
5087
const int src_seqlen = q.h;
5188
const int dst_seqlen = k.h;
@@ -72,7 +109,44 @@ static int test_sdpa_oom_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn
72109
int ret = test_layer_oom("SDPA", pd, weights, as, 1, epsilon);
73110
if (ret != 0)
74111
{
75-
fprintf(stderr, "test_sdpa_oom_int8 failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d scale=%f\n", q.w, q.h, q.c, k.w, k.h, k.c, v.w, v.h, v.c, attn_mask, scale);
112+
fprintf(stderr, "test_sdpa_int8_oom failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d scale=%f\n", q.w, q.h, q.c, k.w, k.h, k.c, v.w, v.h, v.c, attn_mask, scale);
113+
}
114+
115+
return ret;
116+
}
117+
118+
static int test_sdpa_int8_kvcache_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, int past_seqlen)
119+
{
120+
const int embed_dim = q.w;
121+
const int out_embed_dim = v.w;
122+
const int src_seqlen = q.h;
123+
const int cur_seqlen = k.h;
124+
const int dst_seqlen = past_seqlen + cur_seqlen;
125+
126+
ncnn::ParamDict pd;
127+
pd.set(5, attn_mask);
128+
pd.set(7, 1); // kv_cache
129+
pd.set(18, 2); // int8_scale_term
130+
131+
std::vector<ncnn::Mat> weights(0);
132+
133+
std::vector<ncnn::Mat> as(3);
134+
as[0] = q;
135+
as[1] = k;
136+
as[2] = v;
137+
138+
if (attn_mask)
139+
{
140+
as.push_back(RandomMat(dst_seqlen, src_seqlen));
141+
}
142+
143+
as.push_back(RandomMat(embed_dim, past_seqlen, k.c));
144+
as.push_back(RandomMat(out_embed_dim, past_seqlen, v.c));
145+
146+
int ret = test_layer_oom("SDPA", pd, weights, as, 3);
147+
if (ret != 0)
148+
{
149+
fprintf(stderr, "test_sdpa_int8_kvcache_oom failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d past_seqlen=%d\n", q.w, q.h, q.c, k.w, k.h, k.c, v.w, v.h, v.c, attn_mask, past_seqlen);
76150
}
77151

78152
return ret;
@@ -81,10 +155,11 @@ static int test_sdpa_oom_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn
81155
static int test_sdpa_1()
82156
{
83157
return 0
84-
|| test_sdpa_oom_int8(RandomMat(32, 66, 8), RandomMat(32, 66, 8), RandomMat(20, 66, 8), 0)
85-
|| test_sdpa_oom_int8(RandomMat(26, 64, 8), RandomMat(26, 61, 8), RandomMat(18, 61, 8), 1)
86-
|| test_sdpa_oom_int8(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f)
87-
|| test_sdpa_oom_int8(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f);
158+
|| test_sdpa_int8_oom(RandomMat(32, 66, 8), RandomMat(32, 66, 8), RandomMat(20, 66, 8), 0)
159+
|| test_sdpa_int8_oom(RandomMat(26, 64, 8), RandomMat(26, 61, 8), RandomMat(18, 61, 8), 1)
160+
|| test_sdpa_int8_oom(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f)
161+
|| test_sdpa_int8_oom(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f)
162+
|| test_sdpa_int8_kvcache_oom(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 3);
88163
}
89164
#endif
90165

0 commit comments

Comments
 (0)