
Commit 316fe8a

add nsf_univ and export ckpt code
1 parent 08c4e95 commit 316fe8a

File tree: 3 files changed, 511 additions, 0 deletions

export_ckpt.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
import pathlib

import click
import torch
from tqdm import tqdm

from utils import get_latest_checkpoint_path


@click.command(help='Export the generator weights from the latest checkpoint of an experiment')
@click.option('--exp_name', required=True, metavar='EXP', help='Name of the experiment')
@click.option('--save_path', required=True, metavar='PATH', help='Path to save the exported checkpoint')
@click.option('--work_dir', required=False, metavar='DIR', help='Directory where experiments are stored')
def export(exp_name, save_path, work_dir):
    # Resolve the experiment directory (defaults to ./experiments/<exp_name> next to this script).
    if work_dir is None:
        work_dir = pathlib.Path(__file__).parent / 'experiments'
    else:
        work_dir = pathlib.Path(work_dir)
    work_dir = work_dir / exp_name
    assert not work_dir.exists() or work_dir.is_dir(), f'Path \'{work_dir}\' is not a directory.'
    work_dir.mkdir(parents=True, exist_ok=True)

    # Load the latest training checkpoint and keep only the generator weights,
    # stripping the 'generator.' prefix from the state-dict keys.
    state_dict = torch.load(get_latest_checkpoint_path(work_dir))['state_dict']
    generator_ckpt = {}
    for key in tqdm(state_dict):
        if 'generator.' in key:
            generator_ckpt[key.replace('generator.', '')] = state_dict[key]

    torch.save({'generator': generator_ckpt}, save_path)


if __name__ == '__main__':
    export()
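
For reference, the script is invoked as python export_ckpt.py --exp_name <EXP> --save_path <PATH>, and the exported file stores the generator weights under a single 'generator' key with the 'generator.' prefix already stripped. A minimal loading sketch follows; the Generator class and the file name are placeholders, not names taken from this repo:

import torch

ckpt = torch.load('exported.ckpt', map_location='cpu')  # {'generator': {...}}
state_dict = ckpt['generator']                          # prefix-stripped generator weights
# generator = Generator(config)                         # hypothetical model class
# generator.load_state_dict(state_dict)
print(f'{len(state_dict)} generator tensors exported')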

modules/loss/nsf_univloss_msd.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.ddsp.loss import HybridLoss
from modules.loss.stft_loss import warp_stft
from utils.wav2mel import PitchAdjustableMelSpectrogram


class nsf_univloss_msd(nn.Module):
    """GAN losses (MRD + MPD) plus a multi-resolution STFT auxiliary loss for the nsf_univ vocoder."""

    def __init__(self, config: dict):
        super().__init__()
        self.mel = PitchAdjustableMelSpectrogram(
            sample_rate=config['audio_sample_rate'],
            n_fft=config['fft_size'],
            win_length=config['win_size'],
            hop_length=config['hop_size'],
            f_min=config['fmin'],
            f_max=config['fmax_for_loss'],
            n_mels=config['audio_num_mel_bins'],
        )
        self.L1loss = nn.L1Loss()
        self.labauxloss = config.get('lab_aux_loss', 45)
        self.labddsploss = config.get('lab_ddsp_loss', 2)

        # Multi-resolution STFT loss; the resolutions come from the config.
        # Earlier hard-coded defaults, kept for reference:
        #   fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240]
        self.stft = warp_stft({'fft_sizes': config['loss_fft_sizes'],
                               'hop_sizes': config['loss_hop_sizes'],
                               'win_lengths': config['loss_win_lengths']})

        # Step before which the UV mask is detached (only relevant to the disabled DDSP loss).
        self.deuv = config.get('detuv', 2000)

        # DDSP hybrid loss is currently disabled:
        # self.ddsploss = HybridLoss(block_size=config['hop_size'], fft_min=config['ddsp_fftmin'],
        #                            fft_max=config['ddsp_fftmax'], n_scale=config['ddsp_nscale'],
        #                            lambda_uv=config['ddsp_lambdauv'], device='cuda')

    def discriminator_loss(self, disc_real_outputs, disc_generated_outputs):
        # Least-squares GAN objective for the discriminator: real outputs -> 1, generated -> 0.
        loss = 0
        rlosses = 0
        glosses = 0
        r_losses = []
        g_losses = []

        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean((1 - dr) ** 2)
            g_loss = torch.mean(dg ** 2)
            loss += r_loss + g_loss
            rlosses += r_loss.item()
            glosses += g_loss.item()
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())

        return loss, rlosses, glosses, r_losses, g_losses

    def Dloss(self, Dfake, Dtrue):
        # Total discriminator loss over the MRD and MPD discriminators.
        (Fmrd_out, _), (Fmpd_out, _) = Dfake
        (Tmrd_out, _), (Tmpd_out, _) = Dtrue
        mrdloss, mrdrlosses, mrdglosses, _, _ = self.discriminator_loss(Tmrd_out, Fmrd_out)
        mpdloss, mpdrlosses, mpdglosses, _, _ = self.discriminator_loss(Tmpd_out, Fmpd_out)
        loss = mrdloss + mpdloss
        return loss, {'DmrdlossF': mrdglosses, 'DmrdlossT': mrdrlosses,
                      'DmpdlossT': mpdrlosses, 'DmpdlossF': mpdglosses}

    def feature_loss(self, fmap_r, fmap_g):
        # Feature matching: L1 distance between real and generated feature maps,
        # summed over all sub-discriminators and layers, scaled by 2.
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))

        return loss * 2

    def GDloss(self, GDfake, GDtrue):
        # Generator-side adversarial loss (generated outputs -> 1) plus feature matching.
        gen_losses = []
        mrd_losses = 0
        mpd_losses = 0
        (mrd_out, Fmrd_feature), (mpd_out, Fmpd_feature) = GDfake
        (_, Tmrd_feature), (_, Tmpd_feature) = GDtrue

        for dg in mrd_out:
            l = torch.mean((1 - dg) ** 2)
            gen_losses.append(l.item())
            mrd_losses = l + mrd_losses

        for dg in mpd_out:
            l = torch.mean((1 - dg) ** 2)
            gen_losses.append(l.item())
            mpd_losses = l + mpd_losses

        mrd_feature_loss = self.feature_loss(Tmrd_feature, Fmrd_feature)
        mpd_feature_loss = self.feature_loss(Tmpd_feature, Fmpd_feature)
        loss = mpd_feature_loss + mpd_losses + mrd_losses + mrd_feature_loss
        return loss, {'Gmrdloss': mrd_losses, 'Gmpdloss': mpd_losses,
                      'Gmrd_feature_loss': mrd_feature_loss, 'Gmpd_feature_loss': mpd_feature_loss}

    def Auxloss(self, Goutput, sample, step):
        # An earlier mel-spectrogram L1 auxiliary loss was replaced by the multi-resolution STFT loss below.
        # detach_uv only matters for the disabled DDSP loss; kept for when it is re-enabled.
        detach_uv = False
        if step < self.deuv:
            detach_uv = True

        # lossddsp, (loss_rss, loss_uv) = self.ddsploss(Goutput['ddspwav'].squeeze(1), Goutput['s_h'],
        #                                               sample['audio'].squeeze(1), sample['uv'].float(),
        #                                               detach_uv=detach_uv, uv_tolerance=0.15)

        # Multi-resolution STFT loss between generated and reference audio, weighted by lab_aux_loss.
        sc_loss, mag_loss = self.stft.stft(Goutput['audio'].squeeze(1), sample['audio'].squeeze(1))
        loss = (sc_loss + mag_loss) * self.labauxloss
        return loss, {'auxloss': loss, 'auxloss_sc_loss': sc_loss, 'auxloss_mag_loss': mag_loss}
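
The loss module reads everything from a plain config dict; the keys used above are 'audio_sample_rate', 'fft_size', 'win_size', 'hop_size', 'fmin', 'fmax_for_loss', 'audio_num_mel_bins', 'loss_fft_sizes', 'loss_hop_sizes' and 'loss_win_lengths', plus the optional 'lab_aux_loss', 'lab_ddsp_loss' and 'detuv'. As a standalone illustration of the least-squares GAN terms behind discriminator_loss, GDloss and feature_loss, here is a minimal sketch with dummy tensors (the shapes and toy data are invented for the example and are not part of this commit):

import torch

# Toy outputs of one sub-discriminator for a batch of 4 clips.
d_real = torch.rand(4, 1, 100)
d_fake = torch.rand(4, 1, 100)

# Discriminator side (Dloss / discriminator_loss): push real outputs toward 1, generated toward 0.
d_loss = torch.mean((1 - d_real) ** 2) + torch.mean(d_fake ** 2)

# Generator side (GDloss): push generated outputs toward 1.
adv_loss = torch.mean((1 - d_fake) ** 2)

# Feature matching (feature_loss) for a single sub-discriminator:
# L1 distance over its intermediate feature maps, scaled by 2.
fmap_real = [torch.rand(4, 8, 50), torch.rand(4, 16, 25)]
fmap_fake = [torch.rand(4, 8, 50), torch.rand(4, 16, 25)]
fm_loss = 2 * sum(torch.mean(torch.abs(r - g)) for r, g in zip(fmap_real, fmap_fake))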
