update fargan to match version 45
This commit is contained in:
parent
d1c5b32add
commit
9e76a7bfb8
7 changed files with 196 additions and 84 deletions
|
@ -114,20 +114,25 @@ if __name__ == '__main__':
|
|||
for i, (features, periods, target, lpc) in enumerate(tepoch):
|
||||
optimizer.zero_grad()
|
||||
features = features.to(device)
|
||||
lpc = lpc.to(device)
|
||||
#lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
|
||||
#print("interp size", lpc.shape)
|
||||
#lpc = lpc.to(device)
|
||||
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
|
||||
#lpc = fargan.interp_lpc(lpc, 4)
|
||||
periods = periods.to(device)
|
||||
if (np.random.rand() > 0.1):
|
||||
target = target[:, :sequence_length*160]
|
||||
lpc = lpc[:,:sequence_length,:]
|
||||
#lpc = lpc[:,:sequence_length*4,:]
|
||||
features = features[:,:sequence_length+4,:]
|
||||
periods = periods[:,:sequence_length+4]
|
||||
else:
|
||||
target=target[::2, :]
|
||||
lpc=lpc[::2,:]
|
||||
#lpc=lpc[::2,:]
|
||||
features=features[::2,:]
|
||||
periods=periods[::2,:]
|
||||
target = target.to(device)
|
||||
target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
|
||||
#print(target.shape, lpc.shape)
|
||||
#target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
|
||||
|
||||
#nb_pre = random.randrange(1, 6)
|
||||
nb_pre = 2
|
||||
|
@ -135,9 +140,9 @@ if __name__ == '__main__':
|
|||
sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
|
||||
sig = torch.cat([pre, sig], -1)
|
||||
|
||||
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
|
||||
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
|
||||
specc_loss = spect_loss(sig, target.detach())
|
||||
loss = .00*cont_loss + specc_loss
|
||||
loss = .03*cont_loss + specc_loss
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue