|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | + |
| 5 | +from modules.ddsp.loss import HybridLoss |
| 6 | +from modules.loss.stft_loss import warp_stft |
| 7 | +from utils.wav2mel import PitchAdjustableMelSpectrogram |
| 8 | + |
| 9 | + |
class nsf_univloss_msd(nn.Module):
    """Bundle of loss terms for GAN-style audio model training.

    Exposes three entry points:

    * ``Dloss``   -- discriminator loss over two discriminator groups
      (named ``mrd`` and ``mpd`` in the code).
    * ``GDloss``  -- generator adversarial loss plus feature-matching loss
      over the same two groups.
    * ``Auxloss`` -- auxiliary loss computed by the ``warp_stft`` helper,
      scaled by ``lab_aux_loss``.
    """

    def __init__(self, config: dict):
        """Build the helper modules from ``config``.

        Required keys: ``audio_sample_rate``, ``fft_size``, ``win_size``,
        ``hop_size``, ``fmin``, ``fmax_for_loss``, ``audio_num_mel_bins``,
        ``loss_fft_sizes``, ``loss_hop_sizes``, ``loss_win_lengths``.

        Optional keys (default): ``lab_aux_loss`` (45), ``lab_ddsp_loss`` (2),
        ``detuv`` (2000).
        """
        super().__init__()
        self.mel = PitchAdjustableMelSpectrogram(
            sample_rate=config['audio_sample_rate'],
            n_fft=config['fft_size'],
            win_length=config['win_size'],
            hop_length=config['hop_size'],
            f_min=config['fmin'],
            f_max=config['fmax_for_loss'],
            n_mels=config['audio_num_mel_bins'],
        )
        self.L1loss = nn.L1Loss()
        # Weight applied to the auxiliary loss in ``Auxloss``.
        self.labauxloss = config.get('lab_aux_loss', 45)
        # NOTE: the DDSP (HybridLoss) term is currently disabled, so
        # ``labddsploss`` and ``deuv`` are stored but not read by any method
        # of this class. They are kept so external readers of these
        # attributes keep working.
        self.labddsploss = config.get('lab_ddsp_loss', 2)
        self.stft = warp_stft({
            'fft_sizes': config['loss_fft_sizes'],
            'hop_sizes': config['loss_hop_sizes'],
            'win_lengths': config['loss_win_lengths'],
        })
        self.deuv = config.get('detuv', 2000)

    def discriminator_loss(self, disc_real_outputs, disc_generated_outputs):
        """Least-squares discriminator loss over paired output lists.

        For every (real, generated) pair the real output is pushed towards 1
        and the generated output towards 0.

        Returns:
            A 5-tuple ``(loss, rlosses, glosses, r_losses, g_losses)``:
            ``loss`` is the differentiable total (the int ``0`` when the
            inputs are empty); ``rlosses`` / ``glosses`` are the summed real /
            generated terms as Python numbers; ``r_losses`` / ``g_losses``
            are the per-pair values as lists of Python floats.
        """
        loss = 0
        rlosses = 0
        glosses = 0
        r_losses = []
        g_losses = []

        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean((1 - dr) ** 2)
            g_loss = torch.mean(dg ** 2)
            loss += r_loss + g_loss
            # ``.item()`` yields plain floats for logging; read each once.
            r_value = r_loss.item()
            g_value = g_loss.item()
            rlosses += r_value
            glosses += g_value
            r_losses.append(r_value)
            g_losses.append(g_value)

        return loss, rlosses, glosses, r_losses, g_losses

    def Dloss(self, Dfake, Dtrue):
        """Total discriminator loss for the ``mrd`` and ``mpd`` groups.

        Args:
            Dfake: ``((mrd_out, _), (mpd_out, _))`` for generated audio.
            Dtrue: same structure for real audio.

        Returns:
            ``(loss, stats)`` where ``stats`` maps ``'DmrdlossF'``,
            ``'DmrdlossT'``, ``'DmpdlossT'`` and ``'DmpdlossF'`` to the
            summed generated (F) / real (T) terms as Python numbers.
        """
        (Fmrd_out, _), (Fmpd_out, _) = Dfake
        (Tmrd_out, _), (Tmpd_out, _) = Dtrue
        mrdloss, mrdrlosses, mrdglosses, _, _ = self.discriminator_loss(Tmrd_out, Fmrd_out)
        mpdloss, mpdrlosses, mpdglosses, _, _ = self.discriminator_loss(Tmpd_out, Fmpd_out)
        loss = mrdloss + mpdloss
        return loss, {
            'DmrdlossF': mrdglosses,
            'DmrdlossT': mrdrlosses,
            'DmpdlossT': mpdrlosses,
            'DmpdlossF': mpdglosses,
        }

    def feature_loss(self, fmap_r, fmap_g):
        """Feature-matching loss: summed mean absolute difference, times 2.

        ``fmap_r`` and ``fmap_g`` are nested iterables (outer: one entry per
        discriminator, inner: one tensor per layer) paired element-wise.
        """
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))

        return loss * 2

    @staticmethod
    def _generator_adv_loss(disc_outputs):
        """Sum of ``mean((1 - d) ** 2)`` over ``disc_outputs`` (0 if empty)."""
        return sum(torch.mean((1 - dg) ** 2) for dg in disc_outputs)

    def GDloss(self, GDfake, GDtrue):
        """Generator loss: adversarial + feature-matching for both groups.

        Args:
            GDfake: ``((mrd_out, mrd_features), (mpd_out, mpd_features))``
                from the discriminators run on generated audio.
            GDtrue: same structure for real audio; only the feature maps
                are used.

        Returns:
            ``(loss, stats)`` where ``loss`` is the sum of the four terms and
            ``stats`` exposes each term individually.
        """
        (mrd_out, Fmrd_featrue), (mpd_out, Fmpd_featrue) = GDfake
        (_, Tmrd_featrue), (_, Tmpd_featrue) = GDtrue

        mrd_losses = self._generator_adv_loss(mrd_out)
        mpd_losses = self._generator_adv_loss(mpd_out)

        mrd_featrue_loss = self.feature_loss(Tmrd_featrue, Fmrd_featrue)
        mpd_featrue_loss = self.feature_loss(Tmpd_featrue, Fmpd_featrue)

        loss = mpd_featrue_loss + mpd_losses + mrd_losses + mrd_featrue_loss
        # NOTE(review): 'Dmpd_featrue_loss' looks like it was meant to be
        # 'Gmpd_featrue_loss'; the key is kept as-is because downstream
        # logging may depend on it -- confirm before renaming.
        return loss, {
            'Gmrdloss': mrd_losses,
            'Gmpdloss': mpd_losses,
            'Gmrd_featrue_loss': mrd_featrue_loss,
            'Dmpd_featrue_loss': mpd_featrue_loss,
        }

    def Auxloss(self, Goutput, sample, step):
        """Auxiliary STFT loss between generated and reference audio.

        Args:
            Goutput: mapping with an ``'audio'`` tensor (generated).
            sample: mapping with an ``'audio'`` tensor (reference).
            step: training step. Currently unused: it only gated the
                disabled DDSP loss. Kept so existing callers keep working.

        Both audio tensors have dimension 1 squeezed before being handed to
        ``self.stft.stft`` (presumably (batch, 1, samples) -- confirm against
        the data loader).

        Returns:
            ``(loss, stats)`` with ``loss = (sc_loss + mag_loss) *
            self.labauxloss`` and ``stats`` holding the total and both parts.
        """
        sc_loss, mag_loss = self.stft.stft(Goutput['audio'].squeeze(1), sample['audio'].squeeze(1))
        loss = (sc_loss + mag_loss) * self.labauxloss
        return loss, {'auxloss': loss, 'auxloss_sc_loss': sc_loss, 'auxloss_mag_loss': mag_loss}
0 commit comments