Commit 069c5e6

add fast_mpd
1 parent 67820b3 commit 069c5e6

5 files changed: +760, -2 lines changed

configs/nsf_hifigan_fast.yaml

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
# preprocessing
base_config:
  - configs/base_hifi.yaml

data_input_path: []
data_out_path: []
val_num: 5

pe: 'parselmouth' # 'parselmouth' or 'harvest'
f0_min: 65
f0_max: 1100

aug_min: 0.9
aug_max: 1.4
aug_num: 1
key_aug: false
key_aug_prob: 0.5

pc_aug: false # pc-nsf training method
pc_aug_rate: 0.4
pc_aug_key: 12

use_stftloss: true
loss_fft_sizes: [2048, 2048, 4096, 1024, 512, 256, 128, 1024, 2048, 512]
loss_hop_sizes: [512, 240, 480, 100, 50, 25, 12, 120, 240, 50]
loss_win_lengths: [2048, 1200, 2400, 480, 240, 120, 60, 600, 1200, 240]
lab_aux_melloss: 45
lab_aux_stftloss: 2.5
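
With use_stftloss enabled, the three lists above define ten analysis resolutions (one FFT size, hop size, and window length per entry), and lab_aux_melloss / lab_aux_stftloss weight the mel and STFT terms of the auxiliary loss. The loss module the training task calls is not part of this commit; the snippet below is only a rough sketch, assuming the usual spectral-convergence plus log-magnitude formulation of a multi-resolution STFT loss.

# Hedged sketch of a multi-resolution STFT loss driven by the lists above.
# The repo's actual implementation is not shown in this commit.
import torch
import torch.nn.functional as F

def multi_resolution_stft_loss(pred, target, fft_sizes, hop_sizes, win_lengths):
    # pred, target: (batch, samples) waveforms
    total = 0.0
    for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths):
        window = torch.hann_window(win, device=pred.device)
        p = torch.stft(pred, n_fft, hop, win, window=window, return_complex=True).abs()
        t = torch.stft(target, n_fft, hop, win, window=window, return_complex=True).abs()
        sc = torch.norm(t - p, p="fro") / torch.norm(t, p="fro").clamp(min=1e-8)     # spectral convergence
        mag = F.l1_loss(torch.log(p.clamp(min=1e-7)), torch.log(t.clamp(min=1e-7)))  # log-magnitude L1
        total = total + sc + mag
    return total / len(fft_sizes)
    # e.g. multi_resolution_stft_loss(fake_wav, real_wav, loss_fft_sizes, loss_hop_sizes, loss_win_lengths)

With the values above this would run ten STFT analyses per step, scaled by lab_aux_stftloss relative to the mel term weighted by lab_aux_melloss.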

raw_data_dir: []
binary_data_dir: null
binarization_args:
  num_workers: 8
  shuffle: true

DataIndexPath: data
valid_set_name: valid
train_set_name: train


volume_aug: true
volume_aug_prob: 0.5


mel_vmin: -6. #-6.
mel_vmax: 1.5


audio_sample_rate: 44100
audio_num_mel_bins: 128
hop_size: 512 # Hop size.
fft_size: 2048 # FFT size.
win_size: 2048 # Window size.
fmin: 40
fmax: 16000
fmax_for_loss: null
crop_mel_frames: 32



# global constants


# neural networks


#model_cls: training.nsf_HiFigan_task.nsf_HiFigan
model_args:
  mini_nsf: true
  noise_sigma: 0.0
  upsample_rates: [ 8, 8, 2, 2, 2 ]
  upsample_kernel_sizes: [ 16, 16, 4, 4, 4 ]
  upsample_initial_channel: 512
  resblock_kernel_sizes: [ 3, 7, 11 ]
  resblock_dilation_sizes: [ [ 1, 3, 5 ], [ 1, 3, 5 ], [ 1, 3, 5 ] ]
  discriminator_periods: [ 2, 3, 5, 7, 11 ]
  fast_mpd_strides: [4, 4, 4]
  fast_mpd_kernel_size: 11
  resblock: "1"
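
discriminator_periods, fast_mpd_strides and fast_mpd_kernel_size correspond to the constructor arguments of the new FastMPD discriminator added in modules/fast_D/discriminator.py further down this commit. A minimal sketch of that mapping (how training.nsf_HiFigan_fast_task actually reads model_args is not shown here):

# Hedged sketch: building the new discriminator from the model_args above.
from modules.fast_D.discriminator import FastMPD

mpd = FastMPD(
    periods=[2, 3, 5, 7, 11],   # discriminator_periods
    strides=[4, 4, 4],          # fast_mpd_strides
    kernel_size=11,             # fast_mpd_kernel_size
)                               # init_channel keeps its default of 8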

# training

task_cls: training.nsf_HiFigan_fast_task.nsf_HiFigan


#sort_by_len: true
#optimizer_args:
#  optimizer_cls: torch.optim.AdamW
#  lr: 0.0001
#  beta1: 0.9
#  beta2: 0.98
#  weight_decay: 0
#lab_aux_loss: 0.5
discriminate_optimizer_args:
  optimizer_cls: modules.optimizer.muon.Muon_AdamW
  lr: 0.0002
  muon_args:
    weight_decay: 0.03
  adamw_args:
    weight_decay: 0.0
  verbose: false

generater_optimizer_args:
  optimizer_cls: modules.optimizer.muon.Muon_AdamW
  lr: 0.0002
  muon_args:
    weight_decay: 0.03
  adamw_args:
    weight_decay: 0.0
  verbose: false

lr_scheduler_args:
  scheduler_cls: lr_scheduler.scheduler.WarmupLR
  warmup_steps: 5000
  min_lr: 0.00001
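
lr_scheduler.scheduler.WarmupLR itself is not included in this commit, so its exact schedule cannot be read off here; a plausible reading of warmup_steps and min_lr is a linear warm-up to the base learning rate with min_lr acting as a floor, sketched below purely as an assumption.

# Assumed behaviour only -- the real WarmupLR class is not part of this commit.
def assumed_warmup_lr(step, base_lr=0.0002, warmup_steps=5000, min_lr=0.00001):
    if step < warmup_steps:
        lr = base_lr * (step + 1) / warmup_steps   # linear ramp up to base_lr
    else:
        lr = base_lr                               # hold after warm-up
    return max(lr, min_lr)                         # never drop below min_lr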

clip_grad_norm: 1
accumulate_grad_batches: 1
sampler_frame_count_grid: 6
ds_workers: 4
dataloader_prefetch_factor: 2

batch_size: 16



num_valid_plots: 100
log_interval: 100
num_sanity_val_steps: 2 # steps of validation at the beginning
val_check_interval: 2000
num_ckpt_keep: 5
max_updates: 1000000
permanent_ckpt_start: 200000
permanent_ckpt_interval: 40000

###########
# pytorch lightning
# Read https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api for possible values
###########
pl_trainer_accelerator: 'auto'
pl_trainer_devices: 'auto'
pl_trainer_precision: '32-true'
#pl_trainer_precision: 'bf16' # please do not use bf16
pl_trainer_num_nodes: 1
pl_trainer_strategy:
  name: auto
  process_group_backend: nccl
  find_unused_parameters: true
  nccl_p2p: true
seed: 114514

###########
# finetune
###########

finetune_enabled: false
finetune_ckpt_path: ''
finetune_ignored_params: []
finetune_strict_shapes: true

freezing_enabled: false
frozen_params: []

configs/nsf_hifigan_mrd.yaml

Lines changed: 8 additions & 2 deletions
@@ -92,13 +92,19 @@ task_cls: training.nsf_HiFigan_mrd_task.nsf_HiFigan
 discriminate_optimizer_args:
   optimizer_cls: modules.optimizer.muon.Muon_AdamW
   lr: 0.0002
-  weight_decay: 0
+  muon_args:
+    weight_decay: 0.03
+  adamw_args:
+    weight_decay: 0.0
   verbose: false

 generater_optimizer_args:
   optimizer_cls: modules.optimizer.muon.Muon_AdamW
   lr: 0.0002
-  weight_decay: 0
+  muon_args:
+    weight_decay: 0.03
+  adamw_args:
+    weight_decay: 0.0
   verbose: false

 lr_scheduler_args:
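
The same change is made in both configs: weight_decay moves from a single top-level value into separate muon_args and adamw_args blocks. Muon_AdamW itself is not part of this commit; Muon-style hybrid optimizers conventionally update 2-D weight matrices with Muon and everything else (biases, norm scales, embeddings) with AdamW, so a hedged sketch of how the two decay values might be routed is:

# Hedged sketch -- modules.optimizer.muon.Muon_AdamW's real constructor is not shown in this commit.
import torch.nn as nn

def assumed_param_groups(model: nn.Module, muon_weight_decay=0.03, adamw_weight_decay=0.0):
    muon_params = [p for p in model.parameters() if p.ndim >= 2]   # weight matrices -> Muon side
    adamw_params = [p for p in model.parameters() if p.ndim < 2]   # biases, norm scales -> AdamW side
    return (
        {"params": muon_params, "weight_decay": muon_weight_decay},
        {"params": adamw_params, "weight_decay": adamw_weight_decay},
    )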

modules/fast_D/__init__.py

Whitespace-only changes.

modules/fast_D/discriminator.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn


def combine_frames(x, n):
    B, L, C = x.shape
    num_groups = L // n
    if num_groups == 0:
        return torch.empty(B, 0, n * C, device=x.device, dtype=x.dtype)
    x = x[:, :num_groups * n, :].reshape(B, num_groups, n * C)
    return x
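
combine_frames folds every n consecutive frames into the channel axis, so sequence length shrinks by n while width grows by n; FastPD below builds all of its downsampling from this reshape. A quick shape check with made-up sizes:

# Illustrative shape check: 10 frames of 8 channels, grouped 4 at a time.
x = torch.randn(2, 10, 8)
y = combine_frames(x, 4)
print(y.shape)   # torch.Size([2, 2, 32]); the 2 leftover frames are dropped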


class Transpose(nn.Module):
    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)


class LeakyHardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, min_val, max_val, leak_slope):
        if not (min_val < max_val):
            raise ValueError("min_val must be < max_val")
        if leak_slope < 0:
            raise ValueError("leak_slope must be >= 0")
        ctx.min_val = min_val
        ctx.max_val = max_val
        ctx.leak_slope = leak_slope
        below_mask = x < min_val
        any_below = torch.any(below_mask)
        if any_below:
            x[below_mask] = leak_slope * x[below_mask] + (1 - leak_slope) * min_val
        above_mask = x > max_val
        any_above = torch.any(above_mask)
        if any_above:
            x[above_mask] = leak_slope * x[above_mask] + (1 - leak_slope) * max_val
        if any_below or any_above:
            ctx.save_for_backward(below_mask | above_mask)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if len(ctx.saved_tensors) > 0:
            mask, = ctx.saved_tensors
            grad_output[mask] *= ctx.leak_slope
        return grad_output, None, None, None
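
LeakyHardFunction is a leaky hard clip: values inside [min_val, max_val] pass through, values outside are pulled toward the nearest bound with slope leak_slope, and the backward pass scales gradients at the clipped positions by that same slope. Note that forward modifies its input in place, so the forward-only check below (with made-up numbers) passes a clone:

# Forward-only check with illustrative values; clone() because forward mutates its input.
x = torch.tensor([-150.0, 0.0, 50.0, 150.0])
y = LeakyHardFunction.apply(x.clone(), -100, 100, 0.01)
print(y)   # tensor([-100.5000, 0.0000, 50.0000, 100.5000])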


class ATanGLU(nn.Module):
    # ArcTan-gated linear unit (GLU with an atan gate).
    def __init__(self, dim=-1, hard_limit=False):
        super().__init__()
        self.dim = dim
        self.hard_limit = hard_limit

    def forward(self, x):
        if self.hard_limit:
            x = LeakyHardFunction.apply(x, -100, 100, 0.01)
        # out, gate = x.chunk(2, dim=self.dim)
        # Using torch.split instead of chunk for ONNX export compatibility.
        out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return out * torch.atan(gate)
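
ATanGLU halves the chosen dimension and gates one half with the arctangent of the other, a bounded alternative to sigmoid gating; with hard_limit=True the pre-activation is first clipped by the leaky hard clip above. A minimal shape example with made-up sizes:

# Illustrative sizes: the gated output has half the input channels.
glu = ATanGLU(dim=-1)
h = torch.randn(2, 10, 64)   # (batch, frames, channels)
print(glu(h).shape)          # torch.Size([2, 10, 32])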


class LYNXNet2Block(nn.Module):
    def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0.):
        super().__init__()
        inner_dim = int(dim * expansion_factor)
        if float(dropout) > 0.:
            _dropout = nn.Dropout(dropout)
        else:
            _dropout = nn.Identity()
        self.net = nn.Sequential(
            Transpose((1, 2)),
            nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
            Transpose((1, 2)),
            nn.Linear(dim, inner_dim * 2),
            ATanGLU(),
            nn.Linear(inner_dim, inner_dim * 2),
            ATanGLU(hard_limit=True),
            nn.Linear(inner_dim, dim),
            _dropout
        )

    def forward(self, x):
        norm_x = F.rms_norm(x, (x.size(-1), ))
        x = x + self.net(norm_x)
        return x, norm_x
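
LYNXNet2Block is a shape-preserving residual block (depthwise Conv1d followed by two ATanGLU-gated linear layers) that also returns the RMS-normalised input it operated on; FastPD below collects those normalised tensors as feature maps. A small example with illustrative sizes (F.rms_norm needs a recent PyTorch release):

# Illustrative sizes; both outputs keep the input shape.
block = LYNXNet2Block(dim=32, expansion_factor=1, kernel_size=11)
h = torch.randn(2, 50, 32)   # (batch, frames, channels)
out, norm_h = block(h)
print(out.shape, norm_h.shape)   # torch.Size([2, 50, 32]) torch.Size([2, 50, 32])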


class FastPD(torch.nn.Module):
    def __init__(self, period, init_channel=8, strides=[4, 4, 4], kernel_size=11):
        super(FastPD, self).__init__()
        self.period = period
        self.strides = strides
        self.pre = nn.Linear(1, init_channel)
        self.residual_layers = nn.ModuleList(
            [
                LYNXNet2Block(
                    dim=init_channel * np.prod(strides[: i + 1]),
                    expansion_factor=1,
                    kernel_size=kernel_size,
                    dropout=0
                )
                for i in range(len(strides))
            ]
        )
        self.post = nn.Linear(init_channel * np.prod(strides), 1)

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, _, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, 1, t // self.period, self.period)
        x = x.permute(0, 3, 2, 1).reshape(b * self.period, t // self.period, 1)

        x = self.pre(x)
        for i, layer in enumerate(self.residual_layers):
            if self.strides[i] > 1:
                x = combine_frames(x, self.strides[i])
            x, norm_x = layer(x)
            if i > 0:
                fmap.append(norm_x.reshape(b, -1))
        x = self.post(F.rms_norm(x, (x.size(-1), )))
        x = x.reshape(b, -1)

        return x, fmap
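
FastPD reshapes the waveform so that every period-th sample forms its own sequence, embeds each sample with a linear layer, then alternates combine_frames downsampling with LYNXNet2Block stages, so the frame axis shrinks by each stride while the channel count grows by the same factor. A shape walkthrough with made-up sizes:

# Illustrative walkthrough: period 2, default strides [4, 4, 4].
pd = FastPD(period=2, init_channel=8, strides=[4, 4, 4], kernel_size=11)
wav = torch.randn(1, 1, 8192)   # (batch, 1, samples)
score, fmap = pd(wav)
# per-phase frames: 8192 / 2 = 4096 -> 1024 -> 256 -> 64; channels: 8 -> 32 -> 128 -> 512
print(score.shape, [f.shape for f in fmap])
# torch.Size([1, 128]) [torch.Size([1, 65536]), torch.Size([1, 65536])]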


class FastMPD(torch.nn.Module):
    def __init__(self, periods=None, init_channel=8, strides=[1, 2, 4, 4, 2], kernel_size=31):
        super(FastMPD, self).__init__()
        self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList()
        for period in self.periods:
            self.discriminators.append(
                FastPD(period, init_channel=init_channel, strides=strides, kernel_size=kernel_size))

    def forward(self, y):
        y_d_rs = []
        fmap_rs = []

        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)

        return y_d_rs, fmap_rs
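
FastMPD simply runs one FastPD per period and returns the per-period score tensors alongside the feature maps. How training.nsf_HiFigan_fast_task consumes these outputs is not part of this commit; the sketch below only shows the common HiFi-GAN-style pattern (least-squares discriminator loss plus feature matching), with the hyper-parameters taken from configs/nsf_hifigan_fast.yaml and random tensors standing in for real and generated audio.

# Hedged usage sketch -- not the repo's training code.
mpd = FastMPD(periods=[2, 3, 5, 7, 11], strides=[4, 4, 4], kernel_size=11)
real = torch.randn(1, 1, 8192)   # stand-ins for real / generated waveforms
fake = torch.randn(1, 1, 8192)   # (the generated batch would normally be detached for the D step)

real_scores, real_fmaps = mpd(real)
fake_scores, fake_fmaps = mpd(fake)

# Least-squares discriminator loss summed over all periods.
d_loss = sum(((1 - r) ** 2).mean() + (f ** 2).mean()
             for r, f in zip(real_scores, fake_scores))

# Feature-matching term for the generator, comparing the collected norm_x maps.
fm_loss = sum(F.l1_loss(ff, rf.detach())
              for real_f, fake_f in zip(real_fmaps, fake_fmaps)
              for rf, ff in zip(real_f, fake_f))
print(d_loss.item(), fm_loss.item())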
