
Commit 2161d19

fuse non interleaved rotary embed
1 parent 1113017 commit 2161d19

4 files changed: +120, -1 lines changed

tools/pnnx/src/pass_ncnn/fuse_convert_rotaryembed.cpp

Lines changed: 58 additions & 0 deletions

@@ -9,6 +9,62 @@ namespace pnnx {
 
 namespace ncnn {
 
+class fuse_rotaryembed_pass : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 8
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 cos_cache
+pnnx.Input input_2 0 1 sin_cache
+torch.tensor_split op_0 1 2 input 19 20 dim=%split_dim indices=(%embed_dim_half)
+pnnx.Expression op_1 1 1 20 21 expr=neg(@0)
+torch.cat op_2 2 1 21 19 22 dim=%cat_dim
+pnnx.Expression op_3 4 1 input cos_cache 22 sin_cache out expr=add(mul(@0,@1),mul(@2,@3))
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "RotaryEmbed";
+    }
+
+    const char* name_str() const
+    {
+        return "rope";
+    }
+
+    bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
+    {
+        const Operand* input = matched_operators.at("op_0")->inputs[0];
+        if (!input->shape.empty())
+        {
+            const int embed_dim = input->shape[input->shape.size() - 1];
+            const int embed_dim_half = captured_params.at("embed_dim_half").i;
+            if (embed_dim != embed_dim_half * 2)
+                return false;
+        }
+
+        const int split_dim = captured_params.at("split_dim").i;
+        if (split_dim != 3 && split_dim != -1)
+            return false;
+
+        const int cat_dim = captured_params.at("cat_dim").i;
+        if (cat_dim != 3 && cat_dim != -1)
+            return false;
+
+        return true;
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
+    {
+        op->params["0"] = 0; // non-interleaved
+    }
+};
+
 class fuse_rotaryembed_pass_interleaved : public GraphRewriterPass
 {
 public:
@@ -72,9 +128,11 @@ pnnx.Output output 1 0 out
 void fuse_convert_rotaryembed(Graph& graph)
 {
     fuse_rotaryembed_pass_interleaved a;
+    fuse_rotaryembed_pass b;
     int opindex = 0;
 
     pnnx_graph_rewrite(graph, &a, opindex);
+    pnnx_graph_rewrite(graph, &b, opindex);
 }
 
 } // namespace ncnn
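For reference, the pattern graph above matches the non-interleaved ("rotate_half"-style) rotary embedding used by many HuggingFace-style attention implementations, as opposed to the interleaved layout handled by the existing fuse_rotaryembed_pass_interleaved. A minimal PyTorch sketch of the computation that gets fused into a single RotaryEmbed op (function and variable names here are illustrative, not taken from the commit):

import torch

def rope_non_interleaved(x, cos, sin):
    # split the last dimension into two halves (torch.tensor_split with indices=(embed_dim_half))
    half = x.shape[-1] // 2
    x1, x2 = torch.tensor_split(x, (half,), dim=-1)
    # negate the second half and swap the halves back together (neg + cat)
    rotated = torch.cat((-x2, x1), dim=-1)
    # combine with the cos/sin caches: add(mul(x, cos), mul(rotated, sin))
    return x * cos + rotated * sin

The match() checks mirror this shape: the split and cat must act on the last dimension, and the input's embedding dimension, when known, must be exactly twice the captured split index.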

tools/pnnx/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -401,6 +401,7 @@ pnnx_add_test(transformers_mt5_attention)
 pnnx_add_test(transformers_openai_attention)
 pnnx_add_test(transformers_pegasus_attention)
 pnnx_add_test(transformers_prophetnet_attention)
+pnnx_add_test(transformers_qwen3_attention)
 pnnx_add_test(transformers_reformer_attention)
 pnnx_add_test(transformers_roberta_attention)
 pnnx_add_test(transformers_squeezebert_attention)

tools/pnnx/tests/test_transformers_deepseek_v3_attention.py

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def test():
 
     # torchscript to pnnx
    import os
-    os.system("../src/pnnx test_transformers_deepseek_v3_attention.pt inputshape=[3,16,192],[3,1,16,16] fp16=0")
+    os.system("../src/pnnx test_transformers_deepseek_v3_attention.pt inputshape=[3,16,192],[3,1,16,16]")
 
     # pnnx inference
     import test_transformers_deepseek_v3_attention_pnnx
tools/pnnx/tests/test_transformers_qwen3_attention.py (new file)

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+# Copyright 2025 Tencent
+# SPDX-License-Identifier: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+if version.parse(torch.__version__) < version.parse('2.1'):
+    exit(0)
+
+from transformers import Qwen3Config
+from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3RotaryEmbedding
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        config = Qwen3Config(hidden_size=192, num_attention_heads=16, num_key_value_heads=16, q_lora_rank=64, kv_lora_rank=128, attn_implementation='sdpa')
+        self.rotary_emb = Qwen3RotaryEmbedding(config)
+        self.attn0 = Qwen3Attention(config, layer_idx=1)
+
+    def forward(self, x, mask0):
+        batch_size = x.size(0)
+        seq_length = x.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
+        position_embeddings = self.rotary_emb(x, position_ids)
+        out0 = self.attn0(x, position_embeddings=position_embeddings, attention_mask=mask0, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions=True)
+        return out0[0]
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(3, 16, 192)
+
+    mask0 = torch.rand(3, 1, 16, 16)
+
+    a = net(x, mask0)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, mask0))
+    mod.save("test_transformers_qwen3_attention.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_transformers_qwen3_attention.pt inputshape=[3,16,192],[3,1,16,16]")
+
+    # pnnx inference
+    import test_transformers_qwen3_attention_pnnx
+    b = test_transformers_qwen3_attention_pnnx.test_inference()
+
+    return torch.allclose(a, b, 1e-4, 1e-4)
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
