
Commit 5845d8a

swap sin cos, add deepseek_v3 attention test
1 parent 7683a3c

File tree

7 files changed: +166 −5 lines changed


src/layer/rotaryembed.cpp
Lines changed: 5 additions & 5 deletions

@@ -21,8 +21,8 @@ int RotaryEmbed::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
     // assert bottom_blobs.size() == 3
 
     const Mat& bottom_blob = bottom_blobs[0];
-    const Mat& sin_cache = bottom_blobs[1];
-    const Mat& cos_cache = bottom_blobs[2];
+    const Mat& cos_cache = bottom_blobs[1];
+    const Mat& sin_cache = bottom_blobs[2];
 
     const int embed_dim = bottom_blob.w;
     const int seqlen = bottom_blob.h;
@@ -44,16 +44,16 @@ int RotaryEmbed::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
         if (interleaved)
         {
             const float* ptr = head.row(i);
-            const float* sin_ptr = sin_cache.row(i);
             const float* cos_ptr = cos_cache.row(i);
+            const float* sin_ptr = sin_cache.row(i);
             float* outptr = out_head.row(i);
 
             for (int j = 0; j < embed_dim / 2; j++)
             {
                 const float x1 = ptr[0];
                 const float x2 = ptr[1];
-                const float sin_val = *sin_ptr++;
                 const float cos_val = *cos_ptr++;
+                const float sin_val = *sin_ptr++;
                 outptr[0] = x1 * cos_val - x2 * sin_val;
                 outptr[1] = x1 * sin_val + x2 * cos_val;
                 ptr += 2;
@@ -73,8 +73,8 @@ int RotaryEmbed::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&
             {
                 const float x1 = *ptr1++;
                 const float x2 = *ptr2++;
-                const float sin_val = *sin_ptr++;
                 const float cos_val = *cos_ptr++;
+                const float sin_val = *sin_ptr++;
                 *outptr1++ = x1 * cos_val - x2 * sin_val;
                 *outptr2++ = x1 * sin_val + x2 * cos_val;
             }
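
The change only reorders which cache each pointer reads from: the layer now takes cos_cache as bottom_blobs[1] and sin_cache as bottom_blobs[2], while the rotation itself is untouched. A minimal NumPy sketch of the interleaved path, assuming x of shape (seqlen, embed_dim) and caches of shape (seqlen, embed_dim / 2); the function and variable names here are illustrative, not part of the layer:

import numpy as np

def rotary_embed_interleaved(x, cos_cache, sin_cache):
    # even/odd channels form the rotated pairs (x1, x2)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos_cache - x2 * sin_cache
    out[:, 1::2] = x1 * sin_cache + x2 * cos_cache
    return out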

tools/pnnx/src/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -430,6 +430,7 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/eliminate_output.cpp
     pass_ncnn/expand_expression.cpp
     pass_ncnn/fuse_convert_shufflechannel_slice.cpp
+    pass_ncnn/fuse_convert_rotaryembed.cpp
     pass_ncnn/insert_split.cpp
     pass_ncnn/chain_multi_output.cpp
     pass_ncnn/solve_batch_index.cpp

tools/pnnx/src/pass_ncnn.cpp
Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,7 @@
 #include "pass_ncnn/eliminate_output.h"
 #include "pass_ncnn/expand_expression.h"
 #include "pass_ncnn/fuse_convert_shufflechannel_slice.h"
+#include "pass_ncnn/fuse_convert_rotaryembed.h"
 #include "pass_ncnn/insert_split.h"
 #include "pass_ncnn/chain_multi_output.h"
 #include "pass_ncnn/solve_batch_index.h"
@@ -79,6 +80,8 @@ void pass_ncnn(Graph& g, const std::vector<std::string>& module_operators)
 
     attribute_unpooling(g);
 
+    ncnn::fuse_convert_rotaryembed(g);
+
     ncnn::expand_expression(g);
 
     ncnn::chain_multi_output(g);
tools/pnnx/src/pass_ncnn/fuse_convert_rotaryembed.cpp (new file)
Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "fuse_convert_rotaryembed.h"
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+class fuse_rotaryembed_pass_interleaved : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+11 11
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 cos_cache
+pnnx.Input input_2 0 1 sin_cache
+Tensor.reshape op_0 1 1 input 22 shape=(%batch,%num_heads,%seqlen,%embed_dim_half,2)
+torch.transpose op_1 1 1 22 23 dim0=%interleave_dim0 dim1=%interleave_dim1
+Tensor.reshape op_2 1 1 23 24 shape=(%batch,%num_heads,%seqlen,%embed_dim)
+torch.tensor_split op_3 1 2 24 28 29 dim=%split_dim indices=(%embed_dim_half)
+pnnx.Expression op_4 1 1 29 30 expr=neg(@0)
+torch.cat op_5 2 1 30 28 31 dim=%cat_dim
+pnnx.Expression op_6 4 1 24 cos_cache 31 sin_cache out expr=add(mul(@0,@1),mul(@2,@3))
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "RotaryEmbed";
+    }
+
+    const char* name_str() const
+    {
+        return "rope";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        const int embed_dim_half = captured_params.at("embed_dim_half").i;
+        const int embed_dim = captured_params.at("embed_dim").i;
+        if (embed_dim != embed_dim_half * 2)
+            return false;
+
+        const int interleave_dim0 = captured_params.at("interleave_dim0").i;
+        const int interleave_dim1 = captured_params.at("interleave_dim1").i;
+        if (!((interleave_dim0 == 4 && interleave_dim1 == 3) || (interleave_dim0 == 3 && interleave_dim1 == 4)))
+            return false;
+
+        const int split_dim = captured_params.at("split_dim").i;
+        if (split_dim != 3 && split_dim != -1)
+            return false;
+
+        const int cat_dim = captured_params.at("cat_dim").i;
+        if (cat_dim != 3 && cat_dim != -1)
+            return false;
+
+        return true;
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
+    {
+        op->params["0"] = 1; // interleaved
+    }
+};
+
+void fuse_convert_rotaryembed(Graph& graph)
+{
+    fuse_rotaryembed_pass_interleaved a;
+    int opindex = 0;
+
+    pnnx_graph_rewrite(graph, &a, opindex);
+}
+
+} // namespace ncnn
+
+} // namespace pnnx
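
In PyTorch terms, the subgraph this pass matches corresponds to an interleaved RoPE of roughly the form below. This is a hedged sketch to show the op sequence, with an illustrative helper name and 4-D shape assumptions, not the transformers source:

import torch

def interleaved_rope(x, cos, sin):
    b, h, s, d = x.shape
    # interleave the two halves: (..., d/2, 2) -> transpose -> (..., 2, d/2) -> flatten
    x = x.reshape(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    # rotate_half: split at d/2, negate the upper half, swap the halves
    x1, x2 = torch.tensor_split(x, (d // 2,), dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    # pnnx.Expression op_6: add(mul(x, cos), mul(rotated, sin))
    return x * cos + rotated * sin

The match() checks above constrain the captured dims so only this layout is rewritten: a transpose of dims 3 and 4, and split/cat on the last dim.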
tools/pnnx/src/pass_ncnn/fuse_convert_rotaryembed.h (new file)
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+// Copyright 2025 Tencent
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "ir.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+void fuse_convert_rotaryembed(Graph& graph);
+
+} // namespace ncnn
+
+} // namespace pnnx

tools/pnnx/tests/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -383,6 +383,7 @@ pnnx_add_test(transformers_clip_attention)
 pnnx_add_test(transformers_chinese_clip_attention)
 pnnx_add_test(transformers_ctrl_attention)
 pnnx_add_test(transformers_deberta_attention)
+pnnx_add_test(transformers_deepseek_v3_attention)
 pnnx_add_test(transformers_distilbert_attention)
 pnnx_add_test(transformers_electra_attention)
 pnnx_add_test(transformers_flaubert_attention)
tools/pnnx/tests/test_transformers_deepseek_v3_attention.py (new file)
Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+# Copyright 2025 Tencent
+# SPDX-License-Identifier: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+if version.parse(torch.__version__) < version.parse('2.1'):
+    exit(0)
+
+from transformers import DeepseekV3Config
+from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3Attention, DeepseekV3RotaryEmbedding
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        config = DeepseekV3Config(hidden_size=192, num_attention_heads=16, num_key_value_heads=16, q_lora_rank=64, kv_lora_rank=128, attn_implementation='sdpa')
+        self.rotary_emb = DeepseekV3RotaryEmbedding(config)
+        self.attn0 = DeepseekV3Attention(config, layer_idx=1)
+
+    def forward(self, x, mask0):
+        batch_size = x.size(0)
+        seq_length = x.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
+        position_embeddings = self.rotary_emb(x, position_ids)
+        out0 = self.attn0(x, position_embeddings=position_embeddings, attention_mask=mask0, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions=True)
+        return out0[0]
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(3, 16, 192)
+
+    mask0 = torch.rand(3, 1, 16, 16)
+
+    a = net(x, mask0)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, mask0))
+    mod.save("test_transformers_deepseek_v3_attention.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_transformers_deepseek_v3_attention.pt inputshape=[3,16,192],[3,1,16,16] fp16=0")
+
+    # pnnx inference
+    import test_transformers_deepseek_v3_attention_pnnx
+    b = test_transformers_deepseek_v3_attention_pnnx.test_inference()
+
+    return torch.allclose(a, b, 1e-4, 1e-4)
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
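
HF rotary modules such as the DeepseekV3RotaryEmbedding used here return their caches as a (cos, sin) pair, which is the ordering the swapped bottom_blobs indices in rotaryembed.cpp follow. A condensed sketch of how such caches are typically built, for illustration only (base value and dtype handling are simplified assumptions, not the transformers code):

import torch

def make_rope_caches(seqlen, dim, base=10000.0):
    # one inverse frequency per channel pair
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, dim / 2)
    return freqs.cos(), freqs.sin()  # (cos, sin), matching the order rotary_emb returns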
