update fargan to match version 45

This commit is contained in:
Jean-Marc Valin 2023-10-10 00:51:57 -04:00
parent d1c5b32add
commit 9e76a7bfb8
No known key found for this signature in database
GPG key ID: 531A52533318F00A
7 changed files with 196 additions and 84 deletions

View file

@ -132,6 +132,10 @@ states = None
spect_loss = MultiResolutionSTFTLoss(device).to(device) spect_loss = MultiResolutionSTFTLoss(device).to(device)
for param in model.parameters():
param.requires_grad = False
batch_count = 0
if __name__ == '__main__': if __name__ == '__main__':
model.to(device) model.to(device)
disc.to(device) disc.to(device)
@ -153,22 +157,28 @@ if __name__ == '__main__':
print(f"training epoch {epoch}...") print(f"training epoch {epoch}...")
with tqdm.tqdm(dataloader, unit='batch') as tepoch: with tqdm.tqdm(dataloader, unit='batch') as tepoch:
for i, (features, periods, target, lpc) in enumerate(tepoch): for i, (features, periods, target, lpc) in enumerate(tepoch):
if epoch == 1 and i == 400:
for param in model.parameters():
param.requires_grad = True
optimizer.zero_grad() optimizer.zero_grad()
features = features.to(device) features = features.to(device)
lpc = lpc.to(device) #lpc = lpc.to(device)
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
#lpc = fargan.interp_lpc(lpc, 4)
periods = periods.to(device) periods = periods.to(device)
if True: if True:
target = target[:, :sequence_length*160] target = target[:, :sequence_length*160]
lpc = lpc[:,:sequence_length,:] #lpc = lpc[:,:sequence_length*4,:]
features = features[:,:sequence_length+4,:] features = features[:,:sequence_length+4,:]
periods = periods[:,:sequence_length+4] periods = periods[:,:sequence_length+4]
else: else:
target=target[::2, :] target=target[::2, :]
lpc=lpc[::2,:] #lpc=lpc[::2,:]
features=features[::2,:] features=features[::2,:]
periods=periods[::2,:] periods=periods[::2,:]
target = target.to(device) target = target.to(device)
target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma) #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
#nb_pre = random.randrange(1, 6) #nb_pre = random.randrange(1, 6)
nb_pre = 2 nb_pre = 2
@ -208,7 +218,7 @@ if __name__ == '__main__':
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80]) cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
specc_loss = spect_loss(output, target.detach()) specc_loss = spect_loss(output, target.detach())
reg_loss = args.reg_weight * (.00*cont_loss + specc_loss) reg_loss = (.00*cont_loss + specc_loss)
loss_gen = 0 loss_gen = 0
for scale in scores_gen: for scale in scores_gen:
@ -216,7 +226,8 @@ if __name__ == '__main__':
feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen) feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)
gen_loss = reg_loss + feat_loss + loss_gen reg_weight = args.reg_weight + 15./(1 + (batch_count/7600.))
gen_loss = reg_weight * reg_loss + feat_loss + loss_gen
model.zero_grad() model.zero_grad()
@ -238,12 +249,14 @@ if __name__ == '__main__':
tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}", tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
reg_weight=f"{reg_weight:8.5f}",
gen_loss=f"{running_gen_loss/(i+1):8.5f}", gen_loss=f"{running_gen_loss/(i+1):8.5f}",
disc_loss=f"{running_disc_loss/(i+1):8.5f}", disc_loss=f"{running_disc_loss/(i+1):8.5f}",
fmap_loss=f"{running_fmap_loss/(i+1):8.5f}", fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
reg_loss=f"{running_reg_loss/(i+1):8.5f}", reg_loss=f"{running_reg_loss/(i+1):8.5f}",
wc = f"{running_wc/(i+1):8.5f}", wc = f"{running_wc/(i+1):8.5f}",
) )
batch_count = batch_count + 1
# save checkpoint # save checkpoint
checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth') checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')

View file

@ -1,5 +1,6 @@
import torch import torch
import numpy as np import numpy as np
import fargan
class FARGANDataset(torch.utils.data.Dataset): class FARGANDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(self,
@ -34,7 +35,8 @@ class FARGANDataset(torch.utils.data.Dataset):
sizeof = self.features.strides[-1] sizeof = self.features.strides[-1]
self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features), self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features),
strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof)) strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int') #self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
self.periods = np.round(np.clip(256./2**(self.features[:,:,self.nb_used_features-2]+1.5), 32, 255)).astype('int')
self.lpc = self.features[:, :, self.nb_used_features:] self.lpc = self.features[:, :, self.nb_used_features:]
self.features = self.features[:, :, :self.nb_used_features] self.features = self.features[:, :, :self.nb_used_features]
@ -51,5 +53,9 @@ class FARGANDataset(torch.utils.data.Dataset):
lpc = self.lpc[index, 4:, :].copy() lpc = self.lpc[index, 4:, :].copy()
data = self.data[index, :].copy().astype(np.float32) / 2**15 data = self.data[index, :].copy().astype(np.float32) / 2**15
periods = self.periods[index, :].copy() periods = self.periods[index, :].copy()
#lpc = lpc*(self.gamma**np.arange(1,17))
#lpc=lpc[None,:,:]
#lpc = fargan.interp_lpc(lpc, 4)
#lpc=lpc[0,:,:]
return features, periods, data, lpc return features, periods, data, lpc

View file

@ -4,6 +4,8 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import filters import filters
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
#from convert_lsp import lpc_to_lsp, lsp_to_lpc
from rc import lpc2rc, rc2lpc
Fs = 16000 Fs = 16000
@ -27,6 +29,27 @@ def sig_loss(y_true, y_pred):
p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True)) p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
return torch.mean(1.-torch.sum(p*t, dim=-1)) return torch.mean(1.-torch.sum(p*t, dim=-1))
def interp_lpc(lpc, factor):
#print(lpc.shape)
#f = (np.arange(factor)+.5*((factor+1)%2))/factor
lsp = torch.atanh(lpc2rc(lpc))
#print("lsp0:")
#print(lsp)
shape = lsp.shape
#print("shape is", shape)
shape = (shape[0], shape[1]*factor, shape[2])
interp_lsp = torch.zeros(shape, device=lpc.device)
for k in range(factor):
f = (k+.5*((factor+1)%2))/factor
interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
for k in range(factor//2):
interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
for k in range((factor+1)//2):
interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
#print("lsp:")
#print(interp_lsp)
return rc2lpc(torch.tanh(interp_lsp))
def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9): def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
device = x.device device = x.device
@ -39,9 +62,9 @@ def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size)) x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
out = torch.zeros((batch_size, 0), device=device) out = torch.zeros((batch_size, 0), device=device)
if gamma is not None: #if gamma is not None:
bw = gamma**(torch.arange(1, 17, device=device)) # bw = gamma**(torch.arange(1, 17, device=device))
lpc = lpc*bw[None,None,:] # lpc = lpc*bw[None,None,:]
ones = torch.ones((*(lpc.shape[:-1]), 1), device=device) ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device) zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
a = torch.cat([ones, lpc], -1) a = torch.cat([ones, lpc], -1)
@ -127,30 +150,34 @@ class FWConv(nn.Module):
out = self.glu(torch.tanh(self.conv(xcat))) out = self.glu(torch.tanh(self.conv(xcat)))
return out, xcat[:,self.in_size:] return out, xcat[:,self.in_size:]
def n(x):
return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
class FARGANCond(nn.Module): class FARGANCond(nn.Module):
def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64): def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
super(FARGANCond, self).__init__() super(FARGANCond, self).__init__()
self.feature_dim = feature_dim self.feature_dim = feature_dim
self.cond_size = cond_size self.cond_size = cond_size
self.pembed = nn.Embedding(256, pembed_dims) self.pembed = nn.Embedding(224, pembed_dims)
self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False) self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False)
self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False)
self.apply(init_weights) self.apply(init_weights)
nb_params = sum(p.numel() for p in self.parameters())
print(f"cond model: {nb_params} weights")
def forward(self, features, period): def forward(self, features, period):
p = self.pembed(period) p = self.pembed(period-32)
features = torch.cat((features, p), -1) features = torch.cat((features, p), -1)
tmp = torch.tanh(self.fdense1(features)) tmp = torch.tanh(self.fdense1(features))
tmp = tmp.permute(0, 2, 1) tmp = tmp.permute(0, 2, 1)
tmp = torch.tanh(self.fconv1(tmp)) tmp = torch.tanh(self.fconv1(tmp))
tmp = torch.tanh(self.fconv2(tmp)) tmp = torch.tanh(self.fconv2(tmp))
tmp = tmp.permute(0, 2, 1) tmp = tmp.permute(0, 2, 1)
tmp = torch.tanh(self.fdense2(tmp)) #tmp = torch.tanh(self.fdense2(tmp))
return tmp return tmp
class FARGANSub(nn.Module): class FARGANSub(nn.Module):
@ -160,70 +187,87 @@ class FARGANSub(nn.Module):
self.subframe_size = subframe_size self.subframe_size = subframe_size
self.nb_subframes = nb_subframes self.nb_subframes = nb_subframes
self.cond_size = cond_size self.cond_size = cond_size
self.cond_gain_dense = nn.Linear(80, 1)
#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False) #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size) self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size)
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False)
self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.dense1_glu = GLU(self.cond_size) self.dense1_glu = GLU(self.cond_size)
self.dense2_glu = GLU(self.cond_size)
self.gru1_glu = GLU(self.cond_size) self.gru1_glu = GLU(self.cond_size)
self.gru2_glu = GLU(self.cond_size) self.gru2_glu = GLU(128)
self.gru3_glu = GLU(self.cond_size) self.gru3_glu = GLU(128)
self.ptaps_dense = nn.Linear(4*self.cond_size, 5) self.skip_glu = GLU(self.cond_size)
#self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False) self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.gain_dense_out = nn.Linear(4*self.cond_size, 1) self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
self.gain_dense_out = nn.Linear(self.cond_size, 4)
self.apply(init_weights) self.apply(init_weights)
nb_params = sum(p.numel() for p in self.parameters())
print(f"subframe model: {nb_params} weights")
def forward(self, cond, prev, exc_mem, phase, period, states, gain=None): def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
device = exc_mem.device device = exc_mem.device
#print(cond.shape, prev.shape) #print(cond.shape, prev.shape)
dump_signal(prev, 'prev_in.f32') cond = n(cond)
dump_signal(gain, 'gain0.f32')
idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254) gain = torch.exp(self.cond_gain_dense(cond))
dump_signal(gain, 'gain1.f32')
idx = 256-period[:,None]
rng = torch.arange(self.subframe_size+4, device=device) rng = torch.arange(self.subframe_size+4, device=device)
idx = idx + rng[None,:] - 2 idx = idx + rng[None,:] - 2
mask = idx >= 256
idx = idx - mask*period[:,None]
pred = torch.gather(exc_mem, 1, idx) pred = torch.gather(exc_mem, 1, idx)
pred = pred/(1e-5+gain) pred = n(pred/(1e-5+gain))
prev = prev/(1e-5+gain) prev = exc_mem[:,-self.subframe_size:]
dump_signal(prev, 'prev_in.f32')
prev = n(prev/(1e-5+gain))
dump_signal(prev, 'pitch_exc.f32') dump_signal(prev, 'pitch_exc.f32')
dump_signal(exc_mem, 'exc_mem.f32') dump_signal(exc_mem, 'exc_mem.f32')
tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1) tmp = torch.cat((cond, pred, prev), 1)
#tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
gru1_state = self.gru1(dense2_out, states[0])
gru1_out = self.gru1_glu(gru1_state)
gru2_state = self.gru2(gru1_out, states[1])
gru2_out = self.gru2_glu(gru2_state)
gru3_state = self.gru3(gru2_out, states[2])
gru3_out = self.gru3_glu(gru3_state)
gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
sig_out = torch.tanh(self.sig_dense_out(gru3_out))
dump_signal(sig_out, 'exc_out.f32')
taps = self.ptaps_dense(gru3_out)
taps = .2*taps + torch.exp(taps)
taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
dump_signal(taps, 'taps.f32')
#fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
fpitch = pred[:,2:-2] fpitch = pred[:,2:-2]
pitch_gain = torch.exp(self.gain_dense_out(gru3_out)) #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
fwc0_out = n(fwc0_out)
pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))
gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
gru1_out = self.gru1_glu(n(gru1_state))
gru1_out = n(gru1_out)
gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
gru2_out = self.gru2_glu(n(gru2_state))
gru2_out = n(gru2_out)
gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
gru3_out = self.gru3_glu(n(gru3_state))
gru3_out = n(gru3_out)
gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
skip_out = self.skip_glu(n(skip_out))
sig_out = torch.tanh(self.sig_dense_out(skip_out))
dump_signal(sig_out, 'exc_out.f32')
#taps = self.ptaps_dense(gru3_out)
#taps = .2*taps + torch.exp(taps)
#taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
#dump_signal(taps, 'taps.f32')
dump_signal(pitch_gain, 'pgain.f32') dump_signal(pitch_gain, 'pgain.f32')
sig_out = (sig_out + pitch_gain*fpitch) * gain #sig_out = (sig_out + pitch_gain*fpitch) * gain
sig_out = sig_out * gain
exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1) exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
dump_signal(sig_out, 'sig_out.f32') dump_signal(sig_out, 'sig_out.f32')
return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state) return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
class FARGAN(nn.Module): class FARGAN(nn.Module):
def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None): def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@ -242,37 +286,30 @@ class FARGAN(nn.Module):
device = features.device device = features.device
batch_size = features.size(0) batch_size = features.size(0)
phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size) prev = torch.zeros(batch_size, 256, device=device)
#np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')
prev = torch.zeros(batch_size, self.subframe_size, device=device)
exc_mem = torch.zeros(batch_size, 256, device=device) exc_mem = torch.zeros(batch_size, 256, device=device)
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0 nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
states = ( states = (
torch.zeros(batch_size, self.cond_size, device=device), torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.cond_size, device=device), torch.zeros(batch_size, 128, device=device),
torch.zeros(batch_size, self.cond_size, device=device), torch.zeros(batch_size, 128, device=device),
torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device) torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device)
) )
sig = torch.zeros((batch_size, 0), device=device) sig = torch.zeros((batch_size, 0), device=device)
cond = self.cond_net(features, period) cond = self.cond_net(features, period)
if pre is not None: if pre is not None:
prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size]
exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size] exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
start = 1 if nb_pre_frames>0 else 0 start = 1 if nb_pre_frames>0 else 0
for n in range(start, nb_frames+nb_pre_frames): for n in range(start, nb_frames+nb_pre_frames):
for k in range(self.nb_subframes): for k in range(self.nb_subframes):
pos = n*self.frame_size + k*self.subframe_size pos = n*self.frame_size + k*self.subframe_size
preal = phase_real[:, pos:pos+self.subframe_size]
pimag = phase_imag[:, pos:pos+self.subframe_size]
phase = torch.cat([preal, pimag], 1)
#print("now: ", preal.shape, prev.shape, sig_in.shape) #print("now: ", preal.shape, prev.shape, sig_in.shape)
pitch = period[:, 3+n] pitch = period[:, 3+n]
gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0)) gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
#gain = gain[:,:,None] #gain = gain[:,:,None]
out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain) out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)
if n < nb_pre_frames: if n < nb_pre_frames:
out = pre[:, pos:pos+self.subframe_size] out = pre[:, pos:pos+self.subframe_size]
@ -280,6 +317,5 @@ class FARGAN(nn.Module):
else: else:
sig = torch.cat([sig, out], 1) sig = torch.cat([sig, out], 1)
prev = out
states = [s.detach() for s in states] states = [s.detach() for s in states]
return sig, states return sig, states

29
dnn/torch/fargan/rc.py Normal file
View file

@ -0,0 +1,29 @@
import torch
def rc2lpc(rc):
order = rc.shape[-1]
lpc=rc[...,0:1]
for i in range(1, order):
lpc = torch.cat([lpc + rc[...,i:i+1]*torch.flip(lpc,dims=(-1,)), rc[...,i:i+1]], -1)
#print("to:", lpc)
return lpc
def lpc2rc(lpc):
order = lpc.shape[-1]
rc = lpc[...,-1:]
for i in range(order-1, 0, -1):
ki = lpc[...,-1:]
lpc = lpc[...,:-1]
lpc = (lpc - ki*torch.flip(lpc,dims=(-1,)))/(1 - ki*ki)
rc = torch.cat([lpc[...,-1:] , rc], -1)
return rc
if __name__ == "__main__":
rc = torch.tensor([[.5, -.5, .6, -.6]])
print(rc)
lpc = rc2lpc(rc)
print(lpc)
rc2 = lpc2rc(lpc)
print(rc2)

View file

@ -44,7 +44,9 @@ class SpectralConvergenceLoss(torch.nn.Module):
Returns: Returns:
Tensor: Spectral convergence loss value. Tensor: Spectral convergence loss value.
""" """
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") x_mag = torch.sqrt(x_mag)
y_mag = torch.sqrt(y_mag)
return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
class LogSTFTMagnitudeLoss(torch.nn.Module): class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module.""" """Log STFT magnitude loss module."""
@ -136,26 +138,26 @@ class STFTLoss(torch.nn.Module):
class MultiResolutionSTFTLoss(torch.nn.Module): class MultiResolutionSTFTLoss(torch.nn.Module):
def __init__(self, '''def __init__(self,
device, device,
fft_sizes=[2048, 1024, 512, 256, 128, 64], fft_sizes=[2048, 1024, 512, 256, 128, 64],
hop_sizes=[512, 256, 128, 64, 32, 16], hop_sizes=[512, 256, 128, 64, 32, 16],
win_lengths=[2048, 1024, 512, 256, 128, 64], win_lengths=[2048, 1024, 512, 256, 128, 64],
window="hann_window"): window="hann_window"):'''
'''def __init__(self, '''def __init__(self,
device, device,
fft_sizes=[2048, 1024, 512, 256, 128, 64], fft_sizes=[2048, 1024, 512, 256, 128, 64],
hop_sizes=[256, 128, 64, 32, 16, 8], hop_sizes=[256, 128, 64, 32, 16, 8],
win_lengths=[1024, 512, 256, 128, 64, 32], win_lengths=[1024, 512, 256, 128, 64, 32],
window="hann_window"):''' window="hann_window"):'''
'''def __init__(self, def __init__(self,
device, device,
fft_sizes=[2560, 1280, 640, 320, 160, 80], fft_sizes=[2560, 1280, 640, 320, 160, 80],
hop_sizes=[640, 320, 160, 80, 40, 20], hop_sizes=[640, 320, 160, 80, 40, 20],
win_lengths=[2560, 1280, 640, 320, 160, 80], win_lengths=[2560, 1280, 640, 320, 160, 80],
window="hann_window"):''' window="hann_window"):
super(MultiResolutionSTFTLoss, self).__init__() super(MultiResolutionSTFTLoss, self).__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)

View file

@ -48,7 +48,9 @@ model.load_state_dict(checkpoint['state_dict'], strict=False)
features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features)) features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
lpc = features[:,4-1:-1,nb_used_features:] lpc = features[:,4-1:-1,nb_used_features:]
features = features[:, :, :nb_used_features] features = features[:, :, :nb_used_features]
periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int') #periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
periods = np.round(np.clip(256./2**(features[:,:,nb_used_features-2]+1.5), 32, 255)).astype('int')
nb_frames = features.shape[1] nb_frames = features.shape[1]
#nb_frames = 1000 #nb_frames = 1000
@ -90,18 +92,37 @@ def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
buffer[:] = out_sig_frame[-16:] buffer[:] = out_sig_frame[-16:]
return signal return signal
def inverse_perceptual_weighting40 (pw_signal, filters):
#inverse perceptual weighting= H_preemph / W(z/gamma)
signal = np.zeros_like(pw_signal)
buffer = np.zeros(16)
num_frames = pw_signal.shape[0] //40
assert num_frames == filters.shape[0]
for frame_idx in range(0, num_frames):
in_frame = pw_signal[frame_idx*40: (frame_idx+1)*40][:]
out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer)
signal[frame_idx*40: (frame_idx+1)*40] = out_sig_frame[:]
buffer[:] = out_sig_frame[-16:]
return signal
from scipy.signal import lfilter
if __name__ == '__main__': if __name__ == '__main__':
model.to(device) model.to(device)
features = torch.tensor(features).to(device) features = torch.tensor(features).to(device)
#lpc = torch.tensor(lpc).to(device) #lpc = torch.tensor(lpc).to(device)
periods = torch.tensor(periods).to(device) periods = torch.tensor(periods).to(device)
weighting = gamma**np.arange(1, 17)
lpc = lpc*weighting
lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()
sig, _ = model(features, periods, nb_frames - 4) sig, _ = model(features, periods, nb_frames - 4)
weighting_vector = np.array([gamma**i for i in range(16,0,-1)]) #weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
sig = sig.detach().numpy().flatten() sig = sig.detach().numpy().flatten()
sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector) sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
#sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])
pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16') pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
pcm.tofile(signal_file) pcm.tofile(signal_file)

View file

@ -114,20 +114,25 @@ if __name__ == '__main__':
for i, (features, periods, target, lpc) in enumerate(tepoch): for i, (features, periods, target, lpc) in enumerate(tepoch):
optimizer.zero_grad() optimizer.zero_grad()
features = features.to(device) features = features.to(device)
lpc = lpc.to(device) #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
#print("interp size", lpc.shape)
#lpc = lpc.to(device)
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
#lpc = fargan.interp_lpc(lpc, 4)
periods = periods.to(device) periods = periods.to(device)
if (np.random.rand() > 0.1): if (np.random.rand() > 0.1):
target = target[:, :sequence_length*160] target = target[:, :sequence_length*160]
lpc = lpc[:,:sequence_length,:] #lpc = lpc[:,:sequence_length*4,:]
features = features[:,:sequence_length+4,:] features = features[:,:sequence_length+4,:]
periods = periods[:,:sequence_length+4] periods = periods[:,:sequence_length+4]
else: else:
target=target[::2, :] target=target[::2, :]
lpc=lpc[::2,:] #lpc=lpc[::2,:]
features=features[::2,:] features=features[::2,:]
periods=periods[::2,:] periods=periods[::2,:]
target = target.to(device) target = target.to(device)
target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma) #print(target.shape, lpc.shape)
#target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
#nb_pre = random.randrange(1, 6) #nb_pre = random.randrange(1, 6)
nb_pre = 2 nb_pre = 2
@ -135,9 +140,9 @@ if __name__ == '__main__':
sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None) sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
sig = torch.cat([pre, sig], -1) sig = torch.cat([pre, sig], -1)
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80]) cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
specc_loss = spect_loss(sig, target.detach()) specc_loss = spect_loss(sig, target.detach())
loss = .00*cont_loss + specc_loss loss = .03*cont_loss + specc_loss
loss.backward() loss.backward()
optimizer.step() optimizer.step()