Skip to content

Commit 4963895

Browse files
committed
pnnx fuse more rmsnorm pattern
1 parent 559a8e8 commit 4963895

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,23 @@ pnnx.Output output 1 0 out
5454
}
5555
};
5656

57+
class fuse_rmsnorm_pass_2 : public fuse_rmsnorm_pass
58+
{
59+
public:
60+
const char* match_pattern_graph() const
61+
{
62+
return R"PNNXIR(7767517
63+
6 5
64+
pnnx.Input input 0 1 input
65+
pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32
66+
pnnx.Expression op_1 1 1 input sq expr=pow(@0,2)
67+
torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True
68+
pnnx.Expression op_3 3 1 input sqmean weight out expr=mul(mul(@0,rsqrt(add(@1,%eps))),@2)
69+
pnnx.Output output 1 0 out
70+
)PNNXIR";
71+
}
72+
};
73+
5774
class fuse_rmsnorm_pass_without_gamma : public GraphRewriterPass
5875
{
5976
public:
@@ -140,15 +157,17 @@ void fuse_rmsnorm(Graph& graph)
140157
{
141158
fuse_rmsnorm_pass a;
142159
fuse_rmsnorm_pass_1 a1;
143-
fuse_rmsnorm_pass_without_gamma a2;
144-
fuse_rmsnorm_pass_without_gamma_1 a3;
160+
fuse_rmsnorm_pass_2 a2;
161+
fuse_rmsnorm_pass_without_gamma g;
162+
fuse_rmsnorm_pass_without_gamma_1 g1;
145163
fuse_rmsnorm_pass_onnx b;
146164
int opindex = 0;
147165

148166
pnnx_graph_rewrite(graph, &a, opindex);
149167
pnnx_graph_rewrite(graph, &a1, opindex);
150168
pnnx_graph_rewrite(graph, &a2, opindex);
151-
pnnx_graph_rewrite(graph, &a3, opindex);
169+
pnnx_graph_rewrite(graph, &g, opindex);
170+
pnnx_graph_rewrite(graph, &g1, opindex);
152171
pnnx_graph_rewrite(graph, &b, opindex);
153172
}
154173

0 commit comments

Comments
 (0)