Commit 76ce4fe

optimize
1 parent 069c5e6

1 file changed: +11 -12 lines

modules/fast_D/discriminator.py

@@ -34,23 +34,22 @@ def forward(ctx, x, min_val, max_val, leak_slope):
         ctx.max_val = max_val
         ctx.leak_slope = leak_slope
         below_mask = x < min_val
-        any_below = torch.any(below_mask)
-        if any_below:
-            x[below_mask] = leak_slope * x[below_mask] + (1 - leak_slope) * min_val
         above_mask = x > max_val
-        any_above = torch.any(above_mask)
-        if any_above:
-            x[above_mask] = leak_slope * x[above_mask] + (1 - leak_slope) * max_val
-        if any_below or any_above:
-            ctx.save_for_backward(below_mask | above_mask)
+        mask = below_mask | above_mask
+        if mask.any().item():
+            ctx.save_for_backward(mask)
+        below_val = leak_slope * x + (1.0 - leak_slope) * min_val
+        above_val = leak_slope * x + (1.0 - leak_slope) * max_val
+        x = torch.where(below_mask, below_val, x)
+        x = torch.where(above_mask, above_val, x)
         return x

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad):
         if len(ctx.saved_tensors) > 0:
             mask, = ctx.saved_tensors
-            grad_output[mask] *= ctx.leak_slope
-        return grad_output, None, None, None
+            grad = torch.where(mask, grad * ctx.leak_slope, grad)
+        return grad, None, None, None


 class ATanGLU(nn.Module):
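For reference, the forward/backward pair above reassembles into the following minimal runnable sketch of the Function as it reads after this commit. The change is a straightforward vectorization: both leak branches are now computed unconditionally and selected with torch.where, instead of being written in place under two data-dependent Python branches. The class name LeakyHardLimit and the gradcheck values are assumptions for illustration; the diff never shows the class declaration.

import torch

class LeakyHardLimit(torch.autograd.Function):
    # Hypothetical class name; the declaration is outside the visible hunks.
    # Inside [min_val, max_val] the op is the identity; outside, the input
    # is pulled toward the violated bound with slope leak_slope.

    @staticmethod
    def forward(ctx, x, min_val, max_val, leak_slope):
        # The diff also stores max_val on ctx, but backward only reads
        # leak_slope, so only that is kept here.
        ctx.leak_slope = leak_slope
        below_mask = x < min_val
        above_mask = x > max_val
        mask = below_mask | above_mask
        if mask.any().item():
            ctx.save_for_backward(mask)
        # Out-of-place torch.where: x is never mutated, unlike the old
        # in-place masked assignment.
        below_val = leak_slope * x + (1.0 - leak_slope) * min_val
        above_val = leak_slope * x + (1.0 - leak_slope) * max_val
        x = torch.where(below_mask, below_val, x)
        x = torch.where(above_mask, above_val, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        if len(ctx.saved_tensors) > 0:
            mask, = ctx.saved_tensors
            # Scale the incoming gradient by leak_slope wherever the input
            # was clipped; pass it through unchanged elsewhere.
            grad = torch.where(mask, grad * ctx.leak_slope, grad)
        return grad, None, None, None

# Quick numerical check of the custom backward; the sample points are chosen
# away from the kinks at the bounds, where gradcheck would rightly complain.
x = torch.tensor([-2.0, -0.5, 0.3, 1.7], dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(LeakyHardLimit.apply, (x, -1.0, 1.0, 0.1))

The second hunk below then opts the block's first ATanGLU into this hard-limited path as well.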
@@ -82,7 +81,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
             nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
             Transpose((1, 2)),
             nn.Linear(dim, inner_dim * 2),
-            ATanGLU(),
+            ATanGLU(hard_limit=True),
             nn.Linear(inner_dim, inner_dim * 2),
             ATanGLU(hard_limit=True),
             nn.Linear(inner_dim, dim),
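For orientation only, here is one plausible reading of ATanGLU consistent with its use above: a GLU-style gate whose gating half passes through atan, with hard_limit routing the result through the leaky hard-limit Function sketched earlier. The class body is not part of this diff, so everything below is an assumption, including the illustrative bounds and slope.

import torch
import torch.nn as nn

class ATanGLU(nn.Module):
    # Hypothetical reconstruction; the real definition is not in this commit.
    def __init__(self, hard_limit=False):
        super().__init__()
        self.hard_limit = hard_limit

    def forward(self, x):
        # GLU-style split on the channel dim: the upstream
        # nn.Linear(dim, inner_dim * 2) doubles the width, and the
        # gate halves it again.
        a, b = x.chunk(2, dim=-1)
        y = a * torch.atan(b)
        if self.hard_limit:
            # Assumed tie-in to LeakyHardLimit from the sketch above
            # (illustrative values).
            y = LeakyHardLimit.apply(y, -1.0, 1.0, 0.1)
        return y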
