Skip to content

Commit 7ef4d64

Browse files
committed
optimize
1 parent 76ce4fe commit 7ef4d64

File tree

1 file changed

+7
-20
lines changed

1 file changed

+7
-20
lines changed

modules/fast_D/discriminator.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,6 @@
33
import torch.nn.functional as F
44
import torch.nn as nn
55

6-
7-
def combine_frames(x, n):
8-
B, L, C = x.shape
9-
num_groups = L // n
10-
if num_groups == 0:
11-
return torch.empty(B, 0, n * C, device=x.device, dtype=x.dtype)
12-
x = x[:, :num_groups * n, :].reshape(B, num_groups, n * C)
13-
return x
14-
156

167
class Transpose(nn.Module):
178
def __init__(self, dims):
@@ -116,30 +107,26 @@ def __init__(self, period, init_channel=8, strides=[4, 4, 4], kernel_size=11):
116107
def forward(self, x):
117108
fmap = []
118109

119-
# 1d to 2d
120110
b, _, t = x.shape
121-
if t % self.period != 0: # pad first
122-
n_pad = self.period - (t % self.period)
123-
x = F.pad(x, (0, n_pad), "reflect")
124-
t = t + n_pad
125-
x = x.view(b, 1, t // self.period, self.period)
126-
x = x.permute(0, 3, 2, 1).reshape(b * self.period, t // self.period, 1)
111+
n = self.period * np.prod(self.strides)
112+
x = x[:, :, : (t // n) * n].view(b, -1, self.period)
113+
x = x.transpose(1, 2).reshape(b * self.period, -1, 1)
127114

128115
x = self.pre(x)
129116
for i, layer in enumerate(self.residual_layers):
130117
if self.strides[i] > 1:
131-
x = combine_frames(x, self.strides[i])
118+
x = x.view(b, -1, x.size(2) * self.strides[i])
132119
x, norm_x = layer(x)
133120
if i > 0:
134-
fmap.append(norm_x.reshape(b, -1))
121+
fmap.append(norm_x.view(b, -1))
135122
x = self.post(F.rms_norm(x, (x.size(-1), )))
136-
x = x.reshape(b, -1)
123+
x = x.view(b, -1)
137124

138125
return x, fmap
139126

140127

141128
class FastMPD(torch.nn.Module):
142-
def __init__(self,periods=None, init_channel=8, strides=[1, 2, 4, 4, 2], kernel_size=31):
129+
def __init__(self,periods=None, init_channel=8, strides=[4, 4, 4], kernel_size=11):
143130
super(FastMPD, self).__init__()
144131
self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
145132
self.discriminators = nn.ModuleList()

0 commit comments

Comments
 (0)