
Commit 55e32b6

fuse t5 layernorm without gamma

1 parent e59c680

2 files changed: 112 additions, 7 deletions
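The new pass targets the T5-style LayerNorm body: x * rsqrt(mean(x^2, dim=-1) + eps), with no mean subtraction and no gamma multiplication left in the graph. That is exactly RMS normalization with no learned weight. A minimal sketch (not part of the commit; assumes PyTorch >= 2.4 for F.rms_norm):

import torch
import torch.nn.functional as F

x = torch.rand(3, 22, 66)
eps = 1e-6

# T5-style layernorm body: scale by the reciprocal RMS of the last dimension
t5 = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# identical to RMSNorm over the last dimension with weight=None
assert torch.allclose(t5, F.rms_norm(x, (66,), eps=eps))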

tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp

Lines changed: 69 additions & 0 deletions
@@ -54,6 +54,71 @@ pnnx.Output output 1 0 out
     }
 };
 
+class fuse_rmsnorm_pass_without_gamma : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+5 4
+pnnx.Input              input       0 1 input
+pnnx.Expression         op_0        1 1 input sq expr=pow(@0,2)
+torch.mean              op_1        1 1 sq sqmean dim=(-1) keepdim=True
+pnnx.Expression         op_2        2 1 input sqmean out expr=mul(@0,rsqrt(add(@1,%eps)))
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "nn.RMSNorm";
+    }
+
+    const char* name_str() const
+    {
+        return "t5ln";
+    }
+
+    bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& /*captured_attrs*/) const
+    {
+        const Operator* op_0 = matched_operators.at("op_0");
+        const std::vector<int>& shape = op_0->inputs[0]->shape;
+        if (shape.empty())
+        {
+            // unknown normalized_shape
+            return false;
+        }
+
+        return true;
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        const std::vector<int>& shape = op->inputs[0]->shape;
+        const int c = shape[shape.size() - 1];
+
+        op->params["elementwise_affine"] = false;
+        op->params["eps"] = captured_params.at("eps");
+        op->params["normalized_shape"] = std::vector<int>{c};
+    }
+};
+
+class fuse_rmsnorm_pass_without_gamma_1 : public fuse_rmsnorm_pass_without_gamma
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+5 4
+pnnx.Input              input       0 1 input
+pnnx.Expression         op_0        1 1 input sq expr=pow(@0,2)
+torch.mean              op_1        1 1 sq sqmean dim=(-1) keepdim=True
+pnnx.Expression         op_2        2 1 input sqmean out expr=mul(@0,reciprocal(sqrt(add(@1,%eps))))
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+};
+
 class fuse_rmsnorm_pass_onnx : public fuse_rmsnorm_pass
 {
 public:
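The _1 subclass exists because the same normalizer can be traced in two spellings: torch.rsqrt(v) captures as rsqrt(...), while 1 / torch.sqrt(v) or v.sqrt().reciprocal() captures as reciprocal(sqrt(...)). A quick sanity check, not from the commit, that the two forms agree numerically:

import torch

v = torch.rand(3, 22, 1) + 1e-6  # a stand-in for mean(x^2) + eps
assert torch.allclose(torch.rsqrt(v), torch.sqrt(v).reciprocal())

Note also the match() guard above: the fusion bails out when the input shape is unknown, because write() needs the last dimension to fill in normalized_shape.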
@@ -75,11 +140,15 @@ void fuse_rmsnorm(Graph& graph)
 {
     fuse_rmsnorm_pass a;
     fuse_rmsnorm_pass_1 a1;
+    fuse_rmsnorm_pass_without_gamma a2;
+    fuse_rmsnorm_pass_without_gamma_1 a3;
     fuse_rmsnorm_pass_onnx b;
     int opindex = 0;
 
     pnnx_graph_rewrite(graph, &a, opindex);
     pnnx_graph_rewrite(graph, &a1, opindex);
+    pnnx_graph_rewrite(graph, &a2, opindex);
+    pnnx_graph_rewrite(graph, &a3, opindex);
     pnnx_graph_rewrite(graph, &b, opindex);
 }
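Once registered, a match rewrites the three-op subgraph into a single nn.RMSNorm with elementwise_affine=false. A hedged sketch (again assuming PyTorch >= 2.4, where nn.RMSNorm is available) of what the fused operator computes:

import torch
import torch.nn as nn

x = torch.rand(3, 22, 66)
rms = nn.RMSNorm(66, eps=1e-6, elementwise_affine=False)

# matches the pattern the pass consumed
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(rms(x), ref)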
tools/pnnx/tests/test_F_rms_norm.py

Lines changed: 43 additions & 7 deletions
@@ -6,6 +6,32 @@
 import torch.nn.functional as F
 from packaging import version
 
+class T5LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.rand(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+class T5LayerNorm_without_gamma(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
 class Model(nn.Module):
     def __init__(self):
         super(Model, self).__init__()
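T5LayerNorm mirrors the Hugging Face T5 implementation with a random weight and should be picked up by the existing gamma-fusing pass; T5LayerNorm_without_gamma sets the weight to all ones, so the traced multiply presumably folds away and leaves the bare pattern the new pass matches. A small check, assuming the class above is in scope, that the ones-weight module computes plain RMS normalization:

import torch
import torch.nn.functional as F

m = T5LayerNorm_without_gamma(66)  # class defined in the diff above
x = torch.rand(3, 22, 66)
assert torch.allclose(m(x), F.rms_norm(x, (66,), eps=1e-6))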
@@ -14,7 +40,10 @@ def __init__(self):
         self.w4 = nn.Parameter(torch.rand(12, 16))
         self.w5 = nn.Parameter(torch.rand(24))
 
-    def forward(self, x, y, z, w0, w1, w2):
+        self.rmsnorm = T5LayerNorm(66)
+        self.rmsnorm_2 = T5LayerNorm_without_gamma(66)
+
+    def forward(self, x, y, z, w0, w1, w2, x2):
         x = F.rms_norm(x, (24,), w0)
         x = F.rms_norm(x, (12,24), None)
         x = F.rms_norm(x, (24,), self.w3)
@@ -26,7 +55,10 @@ def forward(self, x, y, z, w0, w1, w2):
         z = F.rms_norm(z, (24,), w2)
         z = F.rms_norm(z, (12,16,24), None, eps=1e-2)
         z = F.rms_norm(z, (24,), self.w5)
-        return x, y, z
+
+        x2 = self.rmsnorm(x2)
+        x2 = self.rmsnorm_2(x2)
+        return x, y, z, x2
 
 def test():
     if version.parse(torch.__version__) < version.parse('2.4'):
@@ -42,22 +74,26 @@ def test():
     w0 = torch.rand(24)
     w1 = torch.rand(12, 16)
     w2 = torch.rand(24)
+    x2 = torch.rand(3, 22, 66)
 
-    a0, a1, a2 = net(x, y, z, w0, w1, w2)
+    a = net(x, y, z, w0, w1, w2, x2)
 
     # export torchscript
-    mod = torch.jit.trace(net, (x, y, z, w0, w1, w2))
+    mod = torch.jit.trace(net, (x, y, z, w0, w1, w2, x2))
     mod.save("test_F_rms_norm.pt")
 
     # torchscript to pnnx
     import os
-    os.system("../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24],[24],[12,16],[24]")
+    os.system("../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24],[24],[12,16],[24],[3,22,66]")
 
     # pnnx inference
     import test_F_rms_norm_pnnx
-    b0, b1, b2 = test_F_rms_norm_pnnx.test_inference()
+    b = test_F_rms_norm_pnnx.test_inference()
 
-    return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2)
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
 
 if __name__ == "__main__":
     if test():
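The result comparison now iterates over the returned tuples instead of unpacking exactly three outputs, so the test keeps working as outputs are added. The loop is equivalent to this one-liner, shown only as an illustration:

return all(torch.equal(a0, b0) for a0, b0 in zip(a, b))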
