@@ -35,17 +35,54 @@ static int test_sdpa_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat
3535 return ret;
3636}
3737
38+ static int test_sdpa_kvcache_oom (const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, int past_seqlen)
39+ {
40+ const int embed_dim = q.w ;
41+ const int out_embed_dim = v.w ;
42+ const int src_seqlen = q.h ;
43+ const int cur_seqlen = k.h ;
44+ const int dst_seqlen = past_seqlen + cur_seqlen;
45+
46+ ncnn::ParamDict pd;
47+ pd.set (5 , attn_mask);
48+ pd.set (7 , 1 ); // kv_cache
49+
50+ std::vector<ncnn::Mat> weights (0 );
51+
52+ std::vector<ncnn::Mat> as (3 );
53+ as[0 ] = q;
54+ as[1 ] = k;
55+ as[2 ] = v;
56+
57+ if (attn_mask)
58+ {
59+ as.push_back (RandomMat (dst_seqlen, src_seqlen));
60+ }
61+
62+ as.push_back (RandomMat (embed_dim, past_seqlen, k.c ));
63+ as.push_back (RandomMat (out_embed_dim, past_seqlen, v.c ));
64+
65+ int ret = test_layer_oom (" SDPA" , pd, weights, as, 3 );
66+ if (ret != 0 )
67+ {
68+ fprintf (stderr, " test_sdpa_kvcache_oom failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d past_seqlen=%d\n " , q.w , q.h , q.c , k.w , k.h , k.c , v.w , v.h , v.c , attn_mask, past_seqlen);
69+ }
70+
71+ return ret;
72+ }
73+
3874static int test_sdpa_0 ()
3975{
4076 return 0
4177 || test_sdpa_oom (RandomMat (32 , 66 , 8 ), RandomMat (32 , 66 , 8 ), RandomMat (20 , 66 , 8 ), 0 )
4278 || test_sdpa_oom (RandomMat (26 , 64 , 8 ), RandomMat (26 , 61 , 8 ), RandomMat (18 , 61 , 8 ), 1 )
4379 || test_sdpa_oom (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 , 0 .1f )
44- || test_sdpa_oom (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 , -0 .4f );
80+ || test_sdpa_oom (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 , -0 .4f )
81+ || test_sdpa_kvcache_oom (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 , 3 );
4582}
4683
4784#if NCNN_INT8
48- static int test_sdpa_oom_int8 (const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, float scale = 0 .f)
85+ static int test_sdpa_int8_oom (const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, float scale = 0 .f)
4986{
5087 const int src_seqlen = q.h ;
5188 const int dst_seqlen = k.h ;
@@ -72,7 +109,44 @@ static int test_sdpa_oom_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn
72109 int ret = test_layer_oom (" SDPA" , pd, weights, as, 1 , epsilon);
73110 if (ret != 0 )
74111 {
75- fprintf (stderr, " test_sdpa_oom_int8 failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d scale=%f\n " , q.w , q.h , q.c , k.w , k.h , k.c , v.w , v.h , v.c , attn_mask, scale);
112+ fprintf (stderr, " test_sdpa_int8_oom failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d scale=%f\n " , q.w , q.h , q.c , k.w , k.h , k.c , v.w , v.h , v.c , attn_mask, scale);
113+ }
114+
115+ return ret;
116+ }
117+
118+ static int test_sdpa_int8_kvcache_oom (const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, int past_seqlen)
119+ {
120+ const int embed_dim = q.w ;
121+ const int out_embed_dim = v.w ;
122+ const int src_seqlen = q.h ;
123+ const int cur_seqlen = k.h ;
124+ const int dst_seqlen = past_seqlen + cur_seqlen;
125+
126+ ncnn::ParamDict pd;
127+ pd.set (5 , attn_mask);
128+ pd.set (7 , 1 ); // kv_cache
129+ pd.set (18 , 2 ); // int8_scale_term
130+
131+ std::vector<ncnn::Mat> weights (0 );
132+
133+ std::vector<ncnn::Mat> as (3 );
134+ as[0 ] = q;
135+ as[1 ] = k;
136+ as[2 ] = v;
137+
138+ if (attn_mask)
139+ {
140+ as.push_back (RandomMat (dst_seqlen, src_seqlen));
141+ }
142+
143+ as.push_back (RandomMat (embed_dim, past_seqlen, k.c ));
144+ as.push_back (RandomMat (out_embed_dim, past_seqlen, v.c ));
145+
146+ int ret = test_layer_oom (" SDPA" , pd, weights, as, 3 );
147+ if (ret != 0 )
148+ {
149+ fprintf (stderr, " test_sdpa_int8_kvcache_oom failed q=(%d %d %d) k=(%d %d %d) v=(%d %d %d) attn_mask=%d past_seqlen=%d\n " , q.w , q.h , q.c , k.w , k.h , k.c , v.w , v.h , v.c , attn_mask, past_seqlen);
76150 }
77151
78152 return ret;
@@ -81,10 +155,11 @@ static int test_sdpa_oom_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn
81155static int test_sdpa_1 ()
82156{
83157 return 0
84- || test_sdpa_oom_int8 (RandomMat (32 , 66 , 8 ), RandomMat (32 , 66 , 8 ), RandomMat (20 , 66 , 8 ), 0 )
85- || test_sdpa_oom_int8 (RandomMat (26 , 64 , 8 ), RandomMat (26 , 61 , 8 ), RandomMat (18 , 61 , 8 ), 1 )
86- || test_sdpa_oom_int8 (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 , 0 .1f )
87- || test_sdpa_oom_int8 (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 , -0 .4f );
158+ || test_sdpa_int8_oom (RandomMat (32 , 66 , 8 ), RandomMat (32 , 66 , 8 ), RandomMat (20 , 66 , 8 ), 0 )
159+ || test_sdpa_int8_oom (RandomMat (26 , 64 , 8 ), RandomMat (26 , 61 , 8 ), RandomMat (18 , 61 , 8 ), 1 )
160+ || test_sdpa_int8_oom (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 , 0 .1f )
161+ || test_sdpa_int8_oom (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 , -0 .4f )
162+ || test_sdpa_int8_kvcache_oom (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 , 3 );
88163}
89164#endif
90165
0 commit comments