@@ -105,7 +105,7 @@ static int test_sdpa_self_kvcache_decode(const ncnn::Mat& a)
105105 std::vector<ncnn::Mat> weights (0 );
106106
107107 std::vector<ncnn::Mat> as (3 );
108- as[0 ] = RandomMat (embed_dim, 1 , a.c );
108+ as[0 ] = RandomMat (embed_dim, cur_seqlen , a.c );
109109 as[1 ] = RandomMat (embed_dim, past_seqlen, a.c );
110110 as[2 ] = RandomMat (embed_dim, past_seqlen, a.c );
111111
@@ -127,10 +127,10 @@ static int test_sdpa_0()
127127 || test_sdpa_cross_kvcache (RandomMat (26 , 64 , 8 ), RandomMat (26 , 61 , 8 ), RandomMat (18 , 61 , 8 ), 1 )
128128 || test_sdpa_cross_kvcache (RandomMat (64 , 128 , 12 ), RandomMat (64 , 128 , 2 ), RandomMat (64 , 128 , 2 ), 0 )
129129 || test_sdpa_cross_kvcache (RandomMat (48 , 122 , 12 ), RandomMat (64 , 127 , 2 ), RandomMat (64 , 127 , 2 ), 1 )
130- || test_sdpa_cross_kvcache (RandomMat (44 , 128 , 4 ), RandomMat (44 , 123 , 4 ), RandomMat (55 , 123 , 4 ), 0 , 1 . f )
131- || test_sdpa_cross_kvcache (RandomMat (12 , 127 , 4 ), RandomMat (12 , 127 , 4 ), RandomMat (55 , 127 , 4 ), 1 , 1 . f )
132- || test_sdpa_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 , 0 . 1f )
133- || test_sdpa_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 , - 0 . 4f );
130+ || test_sdpa_cross_kvcache (RandomMat (44 , 128 , 4 ), RandomMat (44 , 123 , 4 ), RandomMat (55 , 123 , 4 ), 0 )
131+ || test_sdpa_cross_kvcache (RandomMat (12 , 127 , 4 ), RandomMat (12 , 127 , 4 ), RandomMat (55 , 127 , 4 ), 1 )
132+ || test_sdpa_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 )
133+ || test_sdpa_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 );
134134}
135135
136136static int test_sdpa_1 ()
@@ -255,7 +255,7 @@ static int test_sdpa_int8_self_kvcache_decode(const ncnn::Mat& a)
255255 std::vector<ncnn::Mat> weights (0 );
256256
257257 std::vector<ncnn::Mat> as (3 );
258- as[0 ] = RandomMat (embed_dim, 1 , a.c );
258+ as[0 ] = RandomMat (embed_dim, cur_seqlen , a.c );
259259 as[1 ] = RandomMat (embed_dim, past_seqlen, a.c );
260260 as[2 ] = RandomMat (embed_dim, past_seqlen, a.c );
261261
@@ -277,10 +277,10 @@ static int test_sdpa_3()
277277 || test_sdpa_int8_cross_kvcache (RandomMat (26 , 64 , 8 ), RandomMat (26 , 61 , 8 ), RandomMat (18 , 61 , 8 ), 1 )
278278 || test_sdpa_int8_cross_kvcache (RandomMat (64 , 128 , 12 ), RandomMat (64 , 128 , 2 ), RandomMat (64 , 128 , 2 ), 0 )
279279 || test_sdpa_int8_cross_kvcache (RandomMat (48 , 122 , 12 ), RandomMat (64 , 127 , 2 ), RandomMat (64 , 127 , 2 ), 1 )
280- || test_sdpa_int8_cross_kvcache (RandomMat (44 , 128 , 4 ), RandomMat (44 , 123 , 4 ), RandomMat (55 , 123 , 4 ), 0 , 1 . f )
281- || test_sdpa_int8_cross_kvcache (RandomMat (12 , 127 , 4 ), RandomMat (12 , 127 , 4 ), RandomMat (55 , 127 , 4 ), 1 , 1 . f )
282- || test_sdpa_int8_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 , 0 . 1f )
283- || test_sdpa_int8_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 , - 0 . 4f );
280+ || test_sdpa_int8_cross_kvcache (RandomMat (44 , 128 , 4 ), RandomMat (44 , 123 , 4 ), RandomMat (55 , 123 , 4 ), 0 )
281+ || test_sdpa_int8_cross_kvcache (RandomMat (12 , 127 , 4 ), RandomMat (12 , 127 , 4 ), RandomMat (55 , 127 , 4 ), 1 )
282+ || test_sdpa_int8_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 127 , 5 ), RandomMat (32 , 127 , 5 ), 0 )
283+ || test_sdpa_int8_cross_kvcache (RandomMat (28 , 17 , 15 ), RandomMat (28 , 32 , 5 ), RandomMat (11 , 32 , 5 ), 1 );
284284}
285285
286286static int test_sdpa_4 ()
0 commit comments