FARGAN initial commit in Opus

Copied/adapted from LPCNet repo
This commit is contained in:
Jean-Marc Valin 2023-08-30 18:36:09 -04:00
parent 4f4b624209
commit 1b13f6313e
No known key found for this signature in database
GPG key ID: 531A52533318F00A
6 changed files with 804 additions and 0 deletions

View file

@ -0,0 +1,52 @@
import torch
import numpy as np
class FARGANDataset(torch.utils.data.Dataset):
def __init__(self,
feature_file,
signal_file,
frame_size=160,
sequence_length=15,
lookahead=1,
nb_used_features=20,
nb_features=36):
self.frame_size = frame_size
self.sequence_length = sequence_length
self.lookahead = lookahead
self.nb_features = nb_features
self.nb_used_features = nb_used_features
pcm_chunk_size = self.frame_size*self.sequence_length
self.data = np.memmap(signal_file, dtype='int16', mode='r')
#self.data = self.data[1::2]
self.nb_sequences = len(self.data)//(pcm_chunk_size)-1
self.data = self.data[(4-self.lookahead)*self.frame_size:]
self.data = self.data[:self.nb_sequences*pcm_chunk_size]
self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))
self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
sizeof = self.features.strides[-1]
self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length+4, nb_features),
strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
self.lpc = self.features[:, :, self.nb_used_features:]
self.features = self.features[:, :, :self.nb_used_features]
print("lpc_size:", self.lpc.shape)
def __len__(self):
return self.nb_sequences
def __getitem__(self, index):
features = self.features[index, :, :].copy()
if self.lookahead != 0:
lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy()
else:
lpc = self.lpc[index, 4:, :].copy()
data = self.data[index, :].copy().astype(np.float32) / 2**15
periods = self.periods[index, :].copy()
return features, periods, data, lpc

260
dnn/torch/fargan/fargan.py Normal file
View file

@ -0,0 +1,260 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import filters
from torch.nn.utils import weight_norm
Fs = 16000
fid_dict = {}
def dump_signal(x, filename):
return
if filename in fid_dict:
fid = fid_dict[filename]
else:
fid = open(filename, "w")
fid_dict[filename] = fid
x = x.detach().numpy().astype('float32')
x.tofile(fid)
def sig_l1(y_true, y_pred):
return torch.mean(abs(y_true-y_pred))/torch.mean(abs(y_true))
def sig_loss(y_true, y_pred):
t = y_true/(1e-15+torch.norm(y_true, dim=-1, p=2, keepdim=True))
p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
return torch.mean(1.-torch.sum(p*t, dim=-1))
def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
device = x.device
batch_size = lpc.size(0)
nb_frames = lpc.shape[1]
sig = torch.zeros(batch_size, subframe_size+16, device=device)
x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
out = torch.zeros((batch_size, 0), device=device)
if gamma is not None:
bw = gamma**(torch.arange(1, 17, device=device))
lpc = lpc*bw[None,None,:]
ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
a = torch.cat([ones, lpc], -1)
a_big = torch.cat([a, zeros], -1)
fir_mat_big = filters.toeplitz_from_filter(a_big)
#print(a_big[:,0,:])
for n in range(nb_frames):
for k in range(nb_subframes):
sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
out = torch.cat([out, exc[:,-subframe_size:,0]], 1)
return out
# weight initialization and clipping
def init_weights(module):
if isinstance(module, nn.GRU):
for p in module.named_parameters():
if p[0].startswith('weight_hh_'):
nn.init.orthogonal_(p[1])
def gen_phase_embedding(periods, frame_size):
device = periods.device
batch_size = periods.size(0)
nb_frames = periods.size(1)
w0 = 2*torch.pi/periods
w0_shift = torch.cat([2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size, w0[:,:-1]], 1)
cum_phase = frame_size*torch.cumsum(w0_shift, 1)
fine_phase = w0[:,:,None]*torch.broadcast_to(torch.arange(frame_size, device=device), (batch_size, nb_frames, frame_size))
embed = torch.unsqueeze(cum_phase, 2) + fine_phase
embed = torch.reshape(embed, (batch_size, -1))
return torch.cos(embed), torch.sin(embed)
class GLU(nn.Module):
def __init__(self, feat_size):
super(GLU, self).__init__()
torch.manual_seed(5)
self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
nn.init.orthogonal_(m.weight.data)
def forward(self, x):
out = x * torch.sigmoid(self.gate(x))
return out
class FARGANCond(nn.Module):
def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
super(FARGANCond, self).__init__()
self.feature_dim = feature_dim
self.cond_size = cond_size
self.pembed = nn.Embedding(256, pembed_dims)
self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
self.apply(init_weights)
def forward(self, features, period):
p = self.pembed(period)
features = torch.cat((features, p), -1)
tmp = torch.tanh(self.fdense1(features))
tmp = tmp.permute(0, 2, 1)
tmp = torch.tanh(self.fconv1(tmp))
tmp = torch.tanh(self.fconv2(tmp))
tmp = tmp.permute(0, 2, 1)
tmp = torch.tanh(self.fdense2(tmp))
return tmp
class FARGANSub(nn.Module):
def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, passthrough_size=0, has_gain=False):
super(FARGANSub, self).__init__()
self.subframe_size = subframe_size
self.nb_subframes = nb_subframes
self.cond_size = cond_size
self.has_gain = has_gain
self.passthrough_size = passthrough_size
print("has_gain:", self.has_gain)
print("passthrough_size:", self.passthrough_size)
gain_param = 1 if self.has_gain else 0
self.sig_dense1 = nn.Linear(3*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.dense1_glu = GLU(self.cond_size)
self.dense2_glu = GLU(self.cond_size)
self.gru1_glu = GLU(self.cond_size)
self.gru2_glu = GLU(self.cond_size)
self.gru3_glu = GLU(self.cond_size)
self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size+self.passthrough_size, bias=False)
if self.has_gain:
self.gain_dense_out = nn.Linear(self.cond_size, 1)
self.apply(init_weights)
def forward(self, cond, prev, exc_mem, phase, period, states):
device = exc_mem.device
#print(cond.shape, prev.shape)
dump_signal(prev, 'prev_in.f32')
idx = 256-torch.maximum(torch.tensor(self.subframe_size, device=device), period[:,None])
rng = torch.arange(self.subframe_size, device=device)
idx = idx + rng[None,:]
prev = torch.gather(exc_mem, 1, idx)
#prev = prev*0
dump_signal(prev, 'pitch_exc.f32')
dump_signal(exc_mem, 'exc_mem.f32')
if self.has_gain:
gain = torch.norm(prev, dim=1, p=2, keepdim=True)
prev = prev/(1e-5+gain)
prev = torch.cat([prev, torch.log(1e-5+gain)], 1)
passthrough = states[3]
tmp = torch.cat((cond, prev, passthrough, phase), 1)
tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
tmp = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
gru1_state = self.gru1(tmp, states[0])
gru2_state = self.gru2(self.gru1_glu(gru1_state), states[1])
gru3_state = self.gru3(self.gru2_glu(gru2_state), states[2])
gru3_out = self.gru3_glu(gru3_state)
sig_out = torch.tanh(self.sig_dense_out(gru3_out))
if self.passthrough_size != 0:
passthrough = sig_out[:,self.subframe_size:]
sig_out = sig_out[:,:self.subframe_size]
if self.has_gain:
out_gain = torch.exp(self.gain_dense_out(gru3_out))
sig_out = sig_out * out_gain
dump_signal(sig_out, 'exc_out.f32')
exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
dump_signal(sig_out, 'sig_out.f32')
return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
class FARGAN(nn.Module):
def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
super(FARGAN, self).__init__()
self.subframe_size = subframe_size
self.nb_subframes = nb_subframes
self.frame_size = self.subframe_size*self.nb_subframes
self.feature_dim = feature_dim
self.cond_size = cond_size
self.has_gain = has_gain
self.passthrough_size = passthrough_size
self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, has_gain=has_gain, passthrough_size=passthrough_size)
def forward(self, features, period, nb_frames, pre=None, states=None):
device = features.device
batch_size = features.size(0)
phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size)
#np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')
prev = torch.zeros(batch_size, self.subframe_size, device=device)
exc_mem = torch.zeros(batch_size, 256, device=device)
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
if states is None:
states = (
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.passthrough_size, device=device)
)
sig = torch.zeros((batch_size, 0), device=device)
cond = self.cond_net(features, period)
passthrough = torch.zeros(batch_size, self.passthrough_size, device=device)
for n in range(nb_frames+nb_pre_frames):
for k in range(self.nb_subframes):
pos = n*self.frame_size + k*self.subframe_size
preal = phase_real[:, pos:pos+self.subframe_size]
pimag = phase_imag[:, pos:pos+self.subframe_size]
phase = torch.cat([preal, pimag], 1)
#print("now: ", preal.shape, prev.shape, sig_in.shape)
pitch = period[:, 3+n]
out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states)
if n < nb_pre_frames:
out = pre[:, pos:pos+self.subframe_size]
exc_mem[:,-self.subframe_size:] = out
else:
sig = torch.cat([sig, out], 1)
prev = out
states = [s.detach() for s in states]
return sig, states

View file

@ -0,0 +1,46 @@
import torch
from torch import nn
import torch.nn.functional as F
import math
def toeplitz_from_filter(a):
device = a.device
L = a.size(-1)
size0 = (*(a.shape[:-1]), L, L+1)
size = (*(a.shape[:-1]), L, L)
rnge = torch.arange(0, L, dtype=torch.int64, device=device)
z = torch.tensor(0, device=device)
idx = torch.maximum(rnge[:,None] - rnge[None,:] + 1, z)
a = torch.cat([a[...,:1]*0, a], -1)
#print(a)
a = a[...,None,:]
#print(idx)
a = torch.broadcast_to(a, size0)
idx = torch.broadcast_to(idx, size)
#print(idx)
return torch.gather(a, -1, idx)
def filter_iir_response(a, N):
device = a.device
L = a.size(-1)
ar = a.flip(dims=(2,))
size = (*(a.shape[:-1]), N)
R = torch.zeros(size, device=device)
R[:,:,0] = torch.ones((a.shape[:-1]), device=device)
for i in range(1, L):
R[:,:,i] = - torch.sum(ar[:,:,L-i-1:-1] * R[:,:,:i], axis=-1)
#R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,L-i-1:-1], R[:,:,:i])
for i in range(L, N):
R[:,:,i] = - torch.sum(ar[:,:,:-1] * R[:,:,i-L+1:i], axis=-1)
#R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,:-1], R[:,:,i-L+1:i])
return R
if __name__ == '__main__':
#a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]], [[1, .9, 0], [1, .8, 0]]])
a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]]])
A = toeplitz_from_filter(a)
#print(A)
R = filter_iir_response(a, 5)
RA = toeplitz_from_filter(R)
print(RA)

View file

@ -0,0 +1,184 @@
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
import numpy as np
import torchaudio
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
#x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
#real = x_stft[..., 0]
#imag = x_stft[..., 1]
# (kan-bayashi): clamp is needed to avoid nan or inf
#return torchaudio.functional.amplitude_to_DB(torch.abs(x_stft),db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
#return torch.clamp(torch.abs(x_stft), min=1e-7)
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
return torch.clamp(torch.abs(x_stft), min=1e-7)
class SpectralConvergenceLoss(torch.nn.Module):
"""Spectral convergence loss module."""
def __init__(self):
"""Initilize spectral convergence loss module."""
super(SpectralConvergenceLoss, self).__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Spectral convergence loss value.
"""
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module."""
def __init__(self):
"""Initilize los STFT magnitude loss module."""
super(LogSTFTMagnitudeLoss, self).__init__()
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Log STFT magnitude loss value.
"""
#F.l1_loss(torch.sqrt(y_mag), torch.sqrt(x_mag)) +
#F.l1_loss(torchaudio.functional.amplitude_to_DB(y_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80),\
#torchaudio.functional.amplitude_to_DB(x_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80))
#y_mag[:,:y_mag.size(1)//2,:] = y_mag[:,:y_mag.size(1)//2,:] *0.0
#return F.l1_loss(torch.log(y_mag) + torch.sqrt(y_mag), torch.log(x_mag) + torch.sqrt(x_mag))
#return F.l1_loss(y_mag, x_mag)
error_loss = F.l1_loss(y, x) #+ F.l1_loss(torch.sqrt(y), torch.sqrt(x))#F.l1_loss(torch.log(y), torch.log(x))#
#x = torch.log(x)
#y = torch.log(y)
#x = x.permute(0,2,1).contiguous()
#y = y.permute(0,2,1).contiguous()
'''mean_x = torch.mean(x, dim=1, keepdim=True)
mean_y = torch.mean(y, dim=1, keepdim=True)
var_x = torch.var(x, dim=1, keepdim=True)
var_y = torch.var(y, dim=1, keepdim=True)
std_x = torch.std(x, dim=1, keepdim=True)
std_y = torch.std(y, dim=1, keepdim=True)
x_minus_mean = x - mean_x
y_minus_mean = y - mean_y
pearson_corr = torch.sum(x_minus_mean * y_minus_mean, dim=1, keepdim=True) / \
(torch.sqrt(torch.sum(x_minus_mean ** 2, dim=1, keepdim=True) + 1e-7) * \
torch.sqrt(torch.sum(y_minus_mean ** 2, dim=1, keepdim=True) + 1e-7))
numerator = 2.0 * pearson_corr * std_x * std_y
denominator = var_x + var_y + (mean_y - mean_x)**2
ccc = numerator/denominator
ccc_loss = F.l1_loss(1.0 - ccc, torch.zeros_like(ccc))'''
return error_loss #+ ccc_loss#+ ccc_loss
class STFTLoss(torch.nn.Module):
"""STFT loss module."""
def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
"""Initialize STFT loss module."""
super(STFTLoss, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length).to(device)
self.spectral_convergenge_loss = SpectralConvergenceLoss()
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T).
y (Tensor): Groundtruth signal (B, T).
Returns:
Tensor: Spectral convergence loss value.
Tensor: Log STFT magnitude loss value.
"""
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
def __init__(self,
device,
fft_sizes=[2048, 1024, 512, 256, 128, 64],
hop_sizes=[512, 256, 128, 64, 32, 16],
win_lengths=[2048, 1024, 512, 256, 128, 64],
window="hann_window"):
'''def __init__(self,
device,
fft_sizes=[2048, 1024, 512, 256, 128, 64],
hop_sizes=[256, 128, 64, 32, 16, 8],
win_lengths=[1024, 512, 256, 128, 64, 32],
window="hann_window"):'''
'''def __init__(self,
device,
fft_sizes=[2560, 1280, 640, 320, 160, 80],
hop_sizes=[640, 320, 160, 80, 40, 20],
win_lengths=[2560, 1280, 640, 320, 160, 80],
window="hann_window"):'''
super(MultiResolutionSTFTLoss, self).__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
self.stft_losses = torch.nn.ModuleList()
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T).
y (Tensor): Groundtruth signal (B, T).
Returns:
Tensor: Multi resolution spectral convergence loss value.
Tensor: Multi resolution log STFT magnitude loss value.
"""
sc_loss = 0.0
mag_loss = 0.0
for f in self.stft_losses:
sc_l, mag_l = f(x, y)
sc_loss += sc_l
#mag_loss += mag_l
sc_loss /= len(self.stft_losses)
mag_loss /= len(self.stft_losses)
return sc_loss #mag_loss #+

View file

@ -0,0 +1,107 @@
import os
import argparse
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import fargan
from dataset import FARGANDataset
nb_features = 36
nb_used_features = 20
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, help='CELPNet model')
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
args = parser.parse_args()
if args.cuda_visible_devices != None:
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
features_file = args.features
signal_file = args.output
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
checkpoint = torch.load(args.model, map_location='cpu')
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
lpc = features[:,4-1:-1,nb_used_features:]
features = features[:, :, :nb_used_features]
periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
nb_frames = features.shape[1]
#nb_frames = 1000
gamma = checkpoint['model_kwargs']['gamma']
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
out = np.zeros_like(frame)
filt = np.flip(filt)
inp = frame[:]
for i in range(0, inp.shape[0]):
s = inp[i] - np.dot(buffer*weighting_vector, filt)
buffer[0] = s
buffer = np.roll(buffer, -1)
out[i] = s
return out
def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
#inverse perceptual weighting= H_preemph / W(z/gamma)
signal = np.zeros_like(pw_signal)
buffer = np.zeros(16)
num_frames = pw_signal.shape[0] //160
assert num_frames == filters.shape[0]
for frame_idx in range(0, num_frames):
in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
buffer[:] = out_sig_frame[-16:]
return signal
if __name__ == '__main__':
model.to(device)
features = torch.tensor(features).to(device)
#lpc = torch.tensor(lpc).to(device)
periods = torch.tensor(periods).to(device)
sig, _ = model(features, periods, nb_frames - 4)
weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
sig = sig.detach().numpy().flatten()
sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)
pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
pcm.tofile(signal_file)

View file

@ -0,0 +1,155 @@
import os
import argparse
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import fargan
from dataset import FARGANDataset
from stft_loss import *
parser = argparse.ArgumentParser()
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')
parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--has-gain', action='store_true', help="use gain-shape network")
model_group.add_argument('--passthrough-size', type=int, help="state passing through in addition to audio, default: 0", default=0)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
args = parser.parse_args()
if args.cuda_visible_devices != None:
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)
# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay
adam_betas = [0.9, 0.99]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal
# model parameters
cond_size = args.cond_size
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'has_gain': args.has_gain, 'passthrough_size': args.passthrough_size, 'gamma': args.gamma}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
#model = fargan.FARGAN()
#model = nn.DataParallel(model)
if type(args.initial_checkpoint) != type(None):
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)
checkpoint['state_dict'] = model.state_dict()
dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
states = None
spect_loss = MultiResolutionSTFTLoss(device).to(device)
if __name__ == '__main__':
model.to(device)
for epoch in range(1, epochs + 1):
running_specc = 0
running_cont_loss = 0
running_loss = 0
print(f"training epoch {epoch}...")
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
for i, (features, periods, target, lpc) in enumerate(tepoch):
optimizer.zero_grad()
features = features.to(device)
lpc = lpc.to(device)
periods = periods.to(device)
target = target.to(device)
target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
#nb_pre = random.randrange(1, 6)
nb_pre = 2
pre = target[:, :nb_pre*160]
sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
sig = torch.cat([pre, sig], -1)
cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
specc_loss = spect_loss(sig, target.detach())
loss = .2*cont_loss + specc_loss
loss.backward()
optimizer.step()
#model.clip_weights()
scheduler.step()
running_specc += specc_loss.detach().cpu().item()
running_cont_loss += cont_loss.detach().cpu().item()
running_loss += loss.detach().cpu().item()
tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
cont_loss=f"{running_cont_loss/(i+1):8.5f}",
specc=f"{running_specc/(i+1):8.5f}",
)
# save checkpoint
checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
checkpoint['state_dict'] = model.state_dict()
checkpoint['loss'] = running_loss / len(dataloader)
checkpoint['epoch'] = epoch
torch.save(checkpoint, checkpoint_path)