1- from ..tensor .operators import TensorBackend
21from ..tensor .functions import rand , zeros
32from .module import Module , Parameter
43from ..backends import fast_conv , fast_ops
5-
6- BACKEND = TensorBackend (fast_ops .FastOps )
4+ from . import init
75
86
97class Linear (Module ):
10- def __init__ (self , in_size , out_size ):
8+ def __init__ (self , in_size , out_size , initializer = init . kaiming_uniform ):
119 super ().__init__ ()
12-
13- # He initialization
14- scale = (2.0 / in_size ) ** 0.5
15- self .weights = Parameter (scale * rand ((in_size , out_size ), backend = BACKEND ))
16- self .bias = Parameter (zeros ((out_size ,), backend = BACKEND ))
10+ self .weights = Parameter (rand ((in_size , out_size )))
11+ initializer (self .weights .value , in_size )
12+ self .bias = Parameter (zeros ((out_size ,)))
1713 self .out_size = out_size
1814
1915 def forward (self , x ):
@@ -24,33 +20,25 @@ def forward(self, x):
2420
2521
2622class Conv1d (Module ):
27- def __init__ (self , in_channels , out_channels , kernel_width ):
23+ def __init__ (self , in_channels , out_channels , kernel_width , initializer = init . kaiming_uniform ):
2824 super ().__init__ ()
29-
30- # He initialization
25+ self .weights = Parameter (rand ((out_channels , in_channels , kernel_width )))
3126 fan_in = in_channels * kernel_width
32- scale = (2.0 / fan_in ) ** 0.5
33- self .weights = Parameter (
34- scale * rand ((out_channels , in_channels , kernel_width ), backend = BACKEND )
35- )
36- self .bias = Parameter (zeros ((1 , out_channels , 1 ), backend = BACKEND ))
27+ initializer (self .weights .value , fan_in )
28+ self .bias = Parameter (zeros ((1 , out_channels , 1 )))
3729
3830 def forward (self , input ):
3931 out = fast_conv .conv1d (input , self .weights .value ) + self .bias .value
4032 return out
4133
4234
4335class Conv2d (Module ):
44- def __init__ (self , in_channels , out_channels , kh , kw ):
36+ def __init__ (self , in_channels , out_channels , kh , kw , initializer = init . kaiming_uniform ):
4537 super ().__init__ ()
46-
47- # He initialization
38+ self .weights = Parameter (rand ((out_channels , in_channels , kh , kw )))
4839 fan_in = in_channels * kh * kw
49- scale = (2.0 / fan_in ) ** 0.5
50- self .weights = Parameter (
51- scale * rand ((out_channels , in_channels , kh , kw ), backend = BACKEND )
52- )
53- self .bias = Parameter (zeros ((out_channels , 1 , 1 ), backend = BACKEND ))
40+ initializer (self .weights .value , fan_in )
41+ self .bias = Parameter (zeros ((out_channels , 1 , 1 )))
5442
5543 def forward (self , input ):
5644 out = fast_conv .conv2d (input , self .weights .value ) + self .bias .value
0 commit comments