Skip to content

Commit 050e83e

Browse files
committed
f
1 parent de17a53 commit 050e83e

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/layer/sdpa.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ int SDPA::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
6262
const int dst_seqlen = past_seqlen > 0 ? (query_i == key_i ? (past_seqlen + cur_seqlen) : past_seqlen) : cur_seqlen;
6363

6464
// assert cur_key.w == embed_dim
65-
// assert cur_key.h == cur_value.h == dst_seqlen
65+
// assert cur_key.h == cur_value.h == cur_seqlen
6666
// assert cur_value.c == num_group
6767
// assert num_heads % num_group == 0
6868

@@ -441,7 +441,7 @@ int SDPA::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& t
441441
const int dst_seqlen = past_seqlen > 0 ? (query_i == key_i ? (past_seqlen + cur_seqlen) : past_seqlen) : cur_seqlen;
442442

443443
// assert cur_key.w == embed_dim
444-
// assert cur_key.h == cur_value.h == dst_seqlen
444+
// assert cur_key.h == cur_value.h == cur_seqlen
445445
// assert cur_value.c == num_group
446446
// assert num_heads % num_group == 0
447447

tests/test_sdpa_kvcache.cpp

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ static int test_sdpa_self_kvcache_decode(const ncnn::Mat& a)
105105
std::vector<ncnn::Mat> weights(0);
106106

107107
std::vector<ncnn::Mat> as(3);
108-
as[0] = RandomMat(embed_dim, 1, a.c);
108+
as[0] = RandomMat(embed_dim, cur_seqlen, a.c);
109109
as[1] = RandomMat(embed_dim, past_seqlen, a.c);
110110
as[2] = RandomMat(embed_dim, past_seqlen, a.c);
111111

@@ -127,10 +127,10 @@ static int test_sdpa_0()
127127
|| test_sdpa_cross_kvcache(RandomMat(26, 64, 8), RandomMat(26, 61, 8), RandomMat(18, 61, 8), 1)
128128
|| test_sdpa_cross_kvcache(RandomMat(64, 128, 12), RandomMat(64, 128, 2), RandomMat(64, 128, 2), 0)
129129
|| test_sdpa_cross_kvcache(RandomMat(48, 122, 12), RandomMat(64, 127, 2), RandomMat(64, 127, 2), 1)
130-
|| test_sdpa_cross_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 1.f)
131-
|| test_sdpa_cross_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 1.f)
132-
|| test_sdpa_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f)
133-
|| test_sdpa_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f);
130+
|| test_sdpa_cross_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0)
131+
|| test_sdpa_cross_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1)
132+
|| test_sdpa_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0)
133+
|| test_sdpa_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1);
134134
}
135135

136136
static int test_sdpa_1()
@@ -255,7 +255,7 @@ static int test_sdpa_int8_self_kvcache_decode(const ncnn::Mat& a)
255255
std::vector<ncnn::Mat> weights(0);
256256

257257
std::vector<ncnn::Mat> as(3);
258-
as[0] = RandomMat(embed_dim, 1, a.c);
258+
as[0] = RandomMat(embed_dim, cur_seqlen, a.c);
259259
as[1] = RandomMat(embed_dim, past_seqlen, a.c);
260260
as[2] = RandomMat(embed_dim, past_seqlen, a.c);
261261

@@ -277,10 +277,10 @@ static int test_sdpa_3()
277277
|| test_sdpa_int8_cross_kvcache(RandomMat(26, 64, 8), RandomMat(26, 61, 8), RandomMat(18, 61, 8), 1)
278278
|| test_sdpa_int8_cross_kvcache(RandomMat(64, 128, 12), RandomMat(64, 128, 2), RandomMat(64, 128, 2), 0)
279279
|| test_sdpa_int8_cross_kvcache(RandomMat(48, 122, 12), RandomMat(64, 127, 2), RandomMat(64, 127, 2), 1)
280-
|| test_sdpa_int8_cross_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 1.f)
281-
|| test_sdpa_int8_cross_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 1.f)
282-
|| test_sdpa_int8_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f)
283-
|| test_sdpa_int8_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f);
280+
|| test_sdpa_int8_cross_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0)
281+
|| test_sdpa_int8_cross_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1)
282+
|| test_sdpa_int8_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0)
283+
|| test_sdpa_int8_cross_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1);
284284
}
285285

286286
static int test_sdpa_4()

0 commit comments

Comments
 (0)