@@ -34,23 +34,22 @@ def forward(ctx, x, min_val, max_val, leak_slope):
         ctx.max_val = max_val
         ctx.leak_slope = leak_slope
         below_mask = x < min_val
-        any_below = torch.any(below_mask)
-        if any_below:
-            x[below_mask] = leak_slope * x[below_mask] + (1 - leak_slope) * min_val
         above_mask = x > max_val
-        any_above = torch.any(above_mask)
-        if any_above:
-            x[above_mask] = leak_slope * x[above_mask] + (1 - leak_slope) * max_val
-        if any_below or any_above:
-            ctx.save_for_backward(below_mask | above_mask)
+        mask = below_mask | above_mask
+        if mask.any().item():
+            ctx.save_for_backward(mask)
+            below_val = leak_slope * x + (1.0 - leak_slope) * min_val
+            above_val = leak_slope * x + (1.0 - leak_slope) * max_val
+            x = torch.where(below_mask, below_val, x)
+            x = torch.where(above_mask, above_val, x)
         return x

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad):
         if len(ctx.saved_tensors) > 0:
             mask, = ctx.saved_tensors
-            grad_output[mask] *= ctx.leak_slope
-        return grad_output, None, None, None
+            grad = torch.where(mask, grad * ctx.leak_slope, grad)
+        return grad, None, None, None


 class ATanGLU(nn.Module):
@@ -82,7 +81,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
             nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
             Transpose((1, 2)),
             nn.Linear(dim, inner_dim * 2),
-            ATanGLU(),
+            ATanGLU(hard_limit=True),
             nn.Linear(inner_dim, inner_dim * 2),
             ATanGLU(hard_limit=True),
             nn.Linear(inner_dim, dim),
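
For reference, a minimal self-contained sketch of the updated forward/backward pair, assuming it lives in a torch.autograd.Function subclass; the class name LeakyClamp and the gradcheck harness are illustrative and not part of this commit. The change replaces in-place masked assignment with out-of-place torch.where, and the backward scales the incoming gradient by leak_slope only where the input was clamped:

# Minimal sketch of the updated clamp, for quick verification.
# Assumptions: the forward/backward above belong to a torch.autograd.Function;
# the class name LeakyClamp and the gradcheck harness are illustrative only.
import torch


class LeakyClamp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, min_val, max_val, leak_slope):
        ctx.leak_slope = leak_slope
        below_mask = x < min_val
        above_mask = x > max_val
        mask = below_mask | above_mask
        if mask.any().item():
            ctx.save_for_backward(mask)
            # Out-of-place torch.where instead of in-place masked assignment,
            # so the caller's tensor is never mutated.
            below_val = leak_slope * x + (1.0 - leak_slope) * min_val
            above_val = leak_slope * x + (1.0 - leak_slope) * max_val
            x = torch.where(below_mask, below_val, x)
            x = torch.where(above_mask, above_val, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        if len(ctx.saved_tensors) > 0:
            mask, = ctx.saved_tensors
            # Gradient passes through at slope 1 inside [min_val, max_val]
            # and is scaled by leak_slope where the input was clamped.
            grad = torch.where(mask, grad * ctx.leak_slope, grad)
        return grad, None, None, None


# Double-precision input for a finite-difference check of the custom gradient.
x = torch.randn(16, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(lambda t: LeakyClamp.apply(t, -1.0, 1.0, 0.1), (x,))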