update fargan to match version 45

This commit is contained in:
Jean-Marc Valin 2023-10-10 00:51:57 -04:00
parent d1c5b32add
commit 9e76a7bfb8
No known key found for this signature in database
GPG key ID: 531A52533318F00A
7 changed files with 196 additions and 84 deletions

View file

@ -114,20 +114,25 @@ if __name__ == '__main__':
for i, (features, periods, target, lpc) in enumerate(tepoch):
optimizer.zero_grad()
features = features.to(device)
lpc = lpc.to(device)
#lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
#print("interp size", lpc.shape)
#lpc = lpc.to(device)
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
#lpc = fargan.interp_lpc(lpc, 4)
periods = periods.to(device)
if (np.random.rand() > 0.1):
target = target[:, :sequence_length*160]
lpc = lpc[:,:sequence_length,:]
#lpc = lpc[:,:sequence_length*4,:]
features = features[:,:sequence_length+4,:]
periods = periods[:,:sequence_length+4]
else:
target=target[::2, :]
lpc=lpc[::2,:]
#lpc=lpc[::2,:]
features=features[::2,:]
periods=periods[::2,:]
target = target.to(device)
target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
#print(target.shape, lpc.shape)
#target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
#nb_pre = random.randrange(1, 6)
nb_pre = 2
@ -135,9 +140,9 @@ if __name__ == '__main__':
sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
sig = torch.cat([pre, sig], -1)
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
specc_loss = spect_loss(sig, target.detach())
loss = .00*cont_loss + specc_loss
loss = .03*cont_loss + specc_loss
loss.backward()
optimizer.step()