@@ -54,6 +54,23 @@ pnnx.Output output 1 0 out
5454 }
5555};
5656
57+ class fuse_rmsnorm_pass_2 : public fuse_rmsnorm_pass
58+ {
59+ public:
60+ const char * match_pattern_graph () const
61+ {
62+ return R"PNNXIR( 7767517
63+ 6 5
64+ pnnx.Input input 0 1 input
65+ pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32
66+ pnnx.Expression op_1 1 1 input sq expr=pow(@0,2)
67+ torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True
68+ pnnx.Expression op_3 3 1 input sqmean weight out expr=mul(mul(@0,rsqrt(add(@1,%eps))),@2)
69+ pnnx.Output output 1 0 out
70+ )PNNXIR" ;
71+ }
72+ };
73+
5774class fuse_rmsnorm_pass_without_gamma : public GraphRewriterPass
5875{
5976public:
@@ -140,15 +157,17 @@ void fuse_rmsnorm(Graph& graph)
140157{
141158 fuse_rmsnorm_pass a;
142159 fuse_rmsnorm_pass_1 a1;
143- fuse_rmsnorm_pass_without_gamma a2;
144- fuse_rmsnorm_pass_without_gamma_1 a3;
160+ fuse_rmsnorm_pass_2 a2;
161+ fuse_rmsnorm_pass_without_gamma g;
162+ fuse_rmsnorm_pass_without_gamma_1 g1;
145163 fuse_rmsnorm_pass_onnx b;
146164 int opindex = 0 ;
147165
148166 pnnx_graph_rewrite (graph, &a, opindex);
149167 pnnx_graph_rewrite (graph, &a1, opindex);
150168 pnnx_graph_rewrite (graph, &a2, opindex);
151- pnnx_graph_rewrite (graph, &a3, opindex);
169+ pnnx_graph_rewrite (graph, &g, opindex);
170+ pnnx_graph_rewrite (graph, &g1, opindex);
152171 pnnx_graph_rewrite (graph, &b, opindex);
153172}
154173
0 commit comments