66import torch .nn .functional as F
77from packaging import version
88
class T5LayerNorm(nn.Module):
    """T5-style RMSNorm: scale each vector by 1/sqrt(mean(x^2) + eps), then by gamma.

    Unlike LayerNorm there is no mean subtraction and no bias term.
    NOTE(review): the gamma (``weight``) is initialised with ``torch.rand``
    rather than ones — presumably deliberate so the conversion test sees a
    non-trivial per-channel scale; confirm against the test harness.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned per-channel scale, randomly initialised (see class note).
        self.weight = nn.Parameter(torch.rand(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalise in float32 for numerical stability, then cast back to
        # the caller's dtype before applying gamma.
        orig_dtype = hidden_states.dtype
        h = hidden_states.float()
        mean_square = h.pow(2).mean(dim=-1, keepdim=True)
        normed = h * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
21+
class T5LayerNorm_without_gamma(nn.Module):
    """RMSNorm variant whose gamma is fixed at ones, i.e. an identity scale.

    The structure mirrors T5LayerNorm — the multiply by ``self.weight`` is
    kept so the traced graph has the same shape, but with ones-init it does
    not change the normalised values.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Ones-initialised scale: multiplying by it is a no-op numerically.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute RMS statistics in float32, then restore the input dtype.
        orig_dtype = hidden_states.dtype
        h = hidden_states.float()
        mean_square = h.pow(2).mean(dim=-1, keepdim=True)
        normed = h * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
34+
935class Model (nn .Module ):
1036 def __init__ (self ):
1137 super (Model , self ).__init__ ()
@@ -14,7 +40,10 @@ def __init__(self):
1440 self .w4 = nn .Parameter (torch .rand (12 , 16 ))
1541 self .w5 = nn .Parameter (torch .rand (24 ))
1642
17- def forward (self , x , y , z , w0 , w1 , w2 ):
43+ self .rmsnorm = T5LayerNorm (66 )
44+ self .rmsnorm_2 = T5LayerNorm_without_gamma (66 )
45+
46+ def forward (self , x , y , z , w0 , w1 , w2 , x2 ):
1847 x = F .rms_norm (x , (24 ,), w0 )
1948 x = F .rms_norm (x , (12 ,24 ), None )
2049 x = F .rms_norm (x , (24 ,), self .w3 )
@@ -26,7 +55,10 @@ def forward(self, x, y, z, w0, w1, w2):
2655 z = F .rms_norm (z , (24 ,), w2 )
2756 z = F .rms_norm (z , (12 ,16 ,24 ), None , eps = 1e-2 )
2857 z = F .rms_norm (z , (24 ,), self .w5 )
29- return x , y , z
58+
59+ x2 = self .rmsnorm (x2 )
60+ x2 = self .rmsnorm_2 (x2 )
61+ return x , y , z , x2
3062
3163def test ():
3264 if version .parse (torch .__version__ ) < version .parse ('2.4' ):
@@ -42,22 +74,26 @@ def test():
4274 w0 = torch .rand (24 )
4375 w1 = torch .rand (12 , 16 )
4476 w2 = torch .rand (24 )
77+ x2 = torch .rand (3 , 22 , 66 )
4578
46- a0 , a1 , a2 = net (x , y , z , w0 , w1 , w2 )
79+ a = net (x , y , z , w0 , w1 , w2 , x2 )
4780
4881 # export torchscript
49- mod = torch .jit .trace (net , (x , y , z , w0 , w1 , w2 ))
82+ mod = torch .jit .trace (net , (x , y , z , w0 , w1 , w2 , x2 ))
5083 mod .save ("test_F_rms_norm.pt" )
5184
5285 # torchscript to pnnx
5386 import os
54- os .system ("../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24],[24],[12,16],[24]" )
87+ os .system ("../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24],[24],[12,16],[24],[3,22,66] " )
5588
5689 # pnnx inference
5790 import test_F_rms_norm_pnnx
58- b0 , b1 , b2 = test_F_rms_norm_pnnx .test_inference ()
91+ b = test_F_rms_norm_pnnx .test_inference ()
5992
60- return torch .equal (a0 , b0 ) and torch .equal (a1 , b1 ) and torch .equal (a2 , b2 )
93+ for a0 , b0 in zip (a , b ):
94+ if not torch .equal (a0 , b0 ):
95+ return False
96+ return True
6197
6298if __name__ == "__main__" :
6399 if test ():
0 commit comments