|
3 | 3 | import torch.nn.functional as F |
4 | 4 | import torch.nn as nn |
5 | 5 |
|
6 | | - |
7 | | -def combine_frames(x, n): |
8 | | - B, L, C = x.shape |
9 | | - num_groups = L // n |
10 | | - if num_groups == 0: |
11 | | - return torch.empty(B, 0, n * C, device=x.device, dtype=x.dtype) |
12 | | - x = x[:, :num_groups * n, :].reshape(B, num_groups, n * C) |
13 | | - return x |
14 | | - |
15 | 6 |
|
16 | 7 | class Transpose(nn.Module): |
17 | 8 | def __init__(self, dims): |
@@ -116,30 +107,26 @@ def __init__(self, period, init_channel=8, strides=[4, 4, 4], kernel_size=11): |
116 | 107 | def forward(self, x): |
117 | 108 | fmap = [] |
118 | 109 |
|
119 | | - # 1d to 2d |
120 | 110 | b, _, t = x.shape |
121 | | - if t % self.period != 0: # pad first |
122 | | - n_pad = self.period - (t % self.period) |
123 | | - x = F.pad(x, (0, n_pad), "reflect") |
124 | | - t = t + n_pad |
125 | | - x = x.view(b, 1, t // self.period, self.period) |
126 | | - x = x.permute(0, 3, 2, 1).reshape(b * self.period, t // self.period, 1) |
| 111 | + n = self.period * np.prod(self.strides) |
| 112 | + x = x[:, :, : (t // n) * n].view(b, -1, self.period) |
| 113 | + x = x.transpose(1, 2).reshape(b * self.period, -1, 1) |
127 | 114 |
|
128 | 115 | x = self.pre(x) |
129 | 116 | for i, layer in enumerate(self.residual_layers): |
130 | 117 | if self.strides[i] > 1: |
131 | | - x = combine_frames(x, self.strides[i]) |
| 118 | + x = x.view(b, -1, x.size(2) * self.strides[i]) |
132 | 119 | x, norm_x = layer(x) |
133 | 120 | if i > 0: |
134 | | - fmap.append(norm_x.reshape(b, -1)) |
| 121 | + fmap.append(norm_x.view(b, -1)) |
135 | 122 | x = self.post(F.rms_norm(x, (x.size(-1), ))) |
136 | | - x = x.reshape(b, -1) |
| 123 | + x = x.view(b, -1) |
137 | 124 |
|
138 | 125 | return x, fmap |
139 | 126 |
|
140 | 127 |
|
141 | 128 | class FastMPD(torch.nn.Module): |
142 | | - def __init__(self,periods=None, init_channel=8, strides=[1, 2, 4, 4, 2], kernel_size=31): |
| 129 | + def __init__(self,periods=None, init_channel=8, strides=[4, 4, 4], kernel_size=11): |
143 | 130 | super(FastMPD, self).__init__() |
144 | 131 | self.periods = periods if periods is not None else [2, 3, 5, 7, 11] |
145 | 132 | self.discriminators = nn.ModuleList() |
|
0 commit comments