Mirror of https://github.com/xiph/opus.git
FARGAN initial commit in Opus
Copied/adapted from LPCNet repo
Parent: 4f4b624209
Commit: 1b13f6313e
6 changed files with 804 additions and 0 deletions
dnn/torch/fargan/dataset.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
import numpy as np


class FARGANDataset(torch.utils.data.Dataset):
    def __init__(self,
                feature_file,
                signal_file,
                frame_size=160,
                sequence_length=15,
                lookahead=1,
                nb_used_features=20,
                nb_features=36):

        self.frame_size = frame_size
        self.sequence_length = sequence_length
        self.lookahead = lookahead
        self.nb_features = nb_features
        self.nb_used_features = nb_used_features
        pcm_chunk_size = self.frame_size*self.sequence_length

        self.data = np.memmap(signal_file, dtype='int16', mode='r')
        #self.data = self.data[1::2]
        self.nb_sequences = len(self.data)//(pcm_chunk_size)-1
        self.data = self.data[(4-self.lookahead)*self.frame_size:]
        self.data = self.data[:self.nb_sequences*pcm_chunk_size]

        self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))

        self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
        sizeof = self.features.strides[-1]
        self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length+4, nb_features),
                                                        strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
        self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')

        self.lpc = self.features[:, :, self.nb_used_features:]
        self.features = self.features[:, :, :self.nb_used_features]
        print("lpc_size:", self.lpc.shape)

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, index):
        features = self.features[index, :, :].copy()
        if self.lookahead != 0:
            lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy()
        else:
            lpc = self.lpc[index, 4:, :].copy()
        data = self.data[index, :].copy().astype(np.float32) / 2**15
        periods = self.periods[index, :].copy()

        return features, periods, data, lpc
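
For reference, a minimal usage sketch (not part of this commit) of how FARGANDataset is instantiated; the file names are hypothetical, and the feature file is assumed to hold 36 float32 values per frame (20 acoustic features followed by 16 LPC coefficients) with the signal file holding the matching 16-bit PCM at 16 kHz:

# Minimal usage sketch (hypothetical paths): load one training sequence.
import torch
from dataset import FARGANDataset

dataset = FARGANDataset('features.f32', 'speech.s16', sequence_length=15)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

features, periods, data, lpc = dataset[0]
print(features.shape)   # (19, 20): sequence_length+4 frames of 20 features
print(periods.shape)    # (19,)   : integer pitch period per frame
print(data.shape)       # (2400,) : sequence_length*160 samples scaled to [-1, 1)
print(lpc.shape)        # (15, 16): 16 LPC coefficients per output frame
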
dnn/torch/fargan/fargan.py (new file, 260 lines)
@@ -0,0 +1,260 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import filters
from torch.nn.utils import weight_norm

Fs = 16000

fid_dict = {}
def dump_signal(x, filename):
    return
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "w")
        fid_dict[filename] = fid
    x = x.detach().numpy().astype('float32')
    x.tofile(fid)


def sig_l1(y_true, y_pred):
    return torch.mean(abs(y_true-y_pred))/torch.mean(abs(y_true))

def sig_loss(y_true, y_pred):
    t = y_true/(1e-15+torch.norm(y_true, dim=-1, p=2, keepdim=True))
    p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
    return torch.mean(1.-torch.sum(p*t, dim=-1))


def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
    device = x.device
    batch_size = lpc.size(0)

    nb_frames = lpc.shape[1]

    sig = torch.zeros(batch_size, subframe_size+16, device=device)
    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
    out = torch.zeros((batch_size, 0), device=device)

    if gamma is not None:
        bw = gamma**(torch.arange(1, 17, device=device))
        lpc = lpc*bw[None,None,:]
    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
    a = torch.cat([ones, lpc], -1)
    a_big = torch.cat([a, zeros], -1)
    fir_mat_big = filters.toeplitz_from_filter(a_big)

    #print(a_big[:,0,:])
    for n in range(nb_frames):
        for k in range(nb_subframes):

            sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
            exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
            out = torch.cat([out, exc[:,-subframe_size:,0]], 1)

    return out


# weight initialization and clipping
def init_weights(module):
    if isinstance(module, nn.GRU):
        for p in module.named_parameters():
            if p[0].startswith('weight_hh_'):
                nn.init.orthogonal_(p[1])

def gen_phase_embedding(periods, frame_size):
    device = periods.device
    batch_size = periods.size(0)
    nb_frames = periods.size(1)
    w0 = 2*torch.pi/periods
    w0_shift = torch.cat([2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size, w0[:,:-1]], 1)
    cum_phase = frame_size*torch.cumsum(w0_shift, 1)
    fine_phase = w0[:,:,None]*torch.broadcast_to(torch.arange(frame_size, device=device), (batch_size, nb_frames, frame_size))
    embed = torch.unsqueeze(cum_phase, 2) + fine_phase
    embed = torch.reshape(embed, (batch_size, -1))
    return torch.cos(embed), torch.sin(embed)

class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):

        out = x * torch.sigmoid(self.gate(x))

        return out


class FARGANCond(nn.Module):
    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
        super(FARGANCond, self).__init__()

        self.feature_dim = feature_dim
        self.cond_size = cond_size

        self.pembed = nn.Embedding(256, pembed_dims)
        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
        self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
        self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
        self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)

        self.apply(init_weights)

    def forward(self, features, period):
        p = self.pembed(period)
        features = torch.cat((features, p), -1)
        tmp = torch.tanh(self.fdense1(features))
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fconv1(tmp))
        tmp = torch.tanh(self.fconv2(tmp))
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fdense2(tmp))
        return tmp

class FARGANSub(nn.Module):
    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, passthrough_size=0, has_gain=False):
        super(FARGANSub, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.cond_size = cond_size
        self.has_gain = has_gain
        self.passthrough_size = passthrough_size

        print("has_gain:", self.has_gain)
        print("passthrough_size:", self.passthrough_size)

        gain_param = 1 if self.has_gain else 0

        self.sig_dense1 = nn.Linear(3*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
        self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
        self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
        self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
        self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)

        self.dense1_glu = GLU(self.cond_size)
        self.dense2_glu = GLU(self.cond_size)
        self.gru1_glu = GLU(self.cond_size)
        self.gru2_glu = GLU(self.cond_size)
        self.gru3_glu = GLU(self.cond_size)

        self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size+self.passthrough_size, bias=False)
        if self.has_gain:
            self.gain_dense_out = nn.Linear(self.cond_size, 1)

        self.apply(init_weights)

    def forward(self, cond, prev, exc_mem, phase, period, states):
        device = exc_mem.device
        #print(cond.shape, prev.shape)

        dump_signal(prev, 'prev_in.f32')

        idx = 256-torch.maximum(torch.tensor(self.subframe_size, device=device), period[:,None])
        rng = torch.arange(self.subframe_size, device=device)
        idx = idx + rng[None,:]
        prev = torch.gather(exc_mem, 1, idx)
        #prev = prev*0
        dump_signal(prev, 'pitch_exc.f32')
        dump_signal(exc_mem, 'exc_mem.f32')
        if self.has_gain:
            gain = torch.norm(prev, dim=1, p=2, keepdim=True)
            prev = prev/(1e-5+gain)
            prev = torch.cat([prev, torch.log(1e-5+gain)], 1)

        passthrough = states[3]
        tmp = torch.cat((cond, prev, passthrough, phase), 1)

        tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
        tmp = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
        gru1_state = self.gru1(tmp, states[0])
        gru2_state = self.gru2(self.gru1_glu(gru1_state), states[1])
        gru3_state = self.gru3(self.gru2_glu(gru2_state), states[2])
        gru3_out = self.gru3_glu(gru3_state)
        sig_out = torch.tanh(self.sig_dense_out(gru3_out))
        if self.passthrough_size != 0:
            passthrough = sig_out[:,self.subframe_size:]
            sig_out = sig_out[:,:self.subframe_size]
        if self.has_gain:
            out_gain = torch.exp(self.gain_dense_out(gru3_out))
            sig_out = sig_out * out_gain
        dump_signal(sig_out, 'exc_out.f32')
        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
        dump_signal(sig_out, 'sig_out.f32')
        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)

class FARGAN(nn.Module):
    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
        super(FARGAN, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.frame_size = self.subframe_size*self.nb_subframes
        self.feature_dim = feature_dim
        self.cond_size = cond_size
        self.has_gain = has_gain
        self.passthrough_size = passthrough_size

        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, has_gain=has_gain, passthrough_size=passthrough_size)

    def forward(self, features, period, nb_frames, pre=None, states=None):
        device = features.device
        batch_size = features.size(0)

        phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size)
        #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')

        prev = torch.zeros(batch_size, self.subframe_size, device=device)
        exc_mem = torch.zeros(batch_size, 256, device=device)
        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0

        if states is None:
            states = (
                torch.zeros(batch_size, self.cond_size, device=device),
                torch.zeros(batch_size, self.cond_size, device=device),
                torch.zeros(batch_size, self.cond_size, device=device),
                torch.zeros(batch_size, self.passthrough_size, device=device)
            )

        sig = torch.zeros((batch_size, 0), device=device)
        cond = self.cond_net(features, period)
        passthrough = torch.zeros(batch_size, self.passthrough_size, device=device)
        for n in range(nb_frames+nb_pre_frames):
            for k in range(self.nb_subframes):
                pos = n*self.frame_size + k*self.subframe_size
                preal = phase_real[:, pos:pos+self.subframe_size]
                pimag = phase_imag[:, pos:pos+self.subframe_size]
                phase = torch.cat([preal, pimag], 1)
                #print("now: ", preal.shape, prev.shape, sig_in.shape)
                pitch = period[:, 3+n]
                out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states)

                if n < nb_pre_frames:
                    out = pre[:, pos:pos+self.subframe_size]
                    exc_mem[:,-self.subframe_size:] = out
                else:
                    sig = torch.cat([sig, out], 1)

                prev = out
        states = [s.detach() for s in states]
        return sig, states
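
A quick shape-check sketch (not part of this commit): run an untrained FARGAN on random inputs, assuming the script runs from dnn/torch/fargan/ so that fargan and filters import. The values are meaningless; only the tensor shapes matter, and the random pitch periods are kept in a range the 256-entry embedding and excitation memory accept.

# Shape-check sketch with random inputs (assumes dnn/torch/fargan/ on the path).
import torch
import fargan

B, nb_frames = 1, 15                                   # 15 output frames of 160 samples
model = fargan.FARGAN(cond_size=256, has_gain=True)

features = torch.randn(B, nb_frames + 4, 20)           # 20 acoustic features per frame
periods = torch.randint(32, 255, (B, nb_frames + 4))   # pitch period in samples per frame
sig, states = model(features, periods, nb_frames)
print(sig.shape)                                       # torch.Size([1, 2400]) == nb_frames*160
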
dnn/torch/fargan/filters.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import torch
from torch import nn
import torch.nn.functional as F
import math

def toeplitz_from_filter(a):
    device = a.device
    L = a.size(-1)
    size0 = (*(a.shape[:-1]), L, L+1)
    size = (*(a.shape[:-1]), L, L)
    rnge = torch.arange(0, L, dtype=torch.int64, device=device)
    z = torch.tensor(0, device=device)
    idx = torch.maximum(rnge[:,None] - rnge[None,:] + 1, z)
    a = torch.cat([a[...,:1]*0, a], -1)
    #print(a)
    a = a[...,None,:]
    #print(idx)
    a = torch.broadcast_to(a, size0)
    idx = torch.broadcast_to(idx, size)
    #print(idx)
    return torch.gather(a, -1, idx)

def filter_iir_response(a, N):
    device = a.device
    L = a.size(-1)
    ar = a.flip(dims=(2,))
    size = (*(a.shape[:-1]), N)
    R = torch.zeros(size, device=device)
    R[:,:,0] = torch.ones((a.shape[:-1]), device=device)
    for i in range(1, L):
        R[:,:,i] = - torch.sum(ar[:,:,L-i-1:-1] * R[:,:,:i], axis=-1)
        #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,L-i-1:-1], R[:,:,:i])
    for i in range(L, N):
        R[:,:,i] = - torch.sum(ar[:,:,:-1] * R[:,:,i-L+1:i], axis=-1)
        #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,:-1], R[:,:,i-L+1:i])
    return R

if __name__ == '__main__':
    #a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]], [[1, .9, 0], [1, .8, 0]]])
    a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]]])
    A = toeplitz_from_filter(a)
    #print(A)
    R = filter_iir_response(a, 5)

    RA = toeplitz_from_filter(R)
    print(RA)
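
A small cross-check sketch (not part of this commit, and assuming scipy is available): the last dimension of filter_iir_response() should match the impulse response of the all-pole filter 1/A(z) computed with scipy.signal.lfilter.

# Cross-check sketch: filter_iir_response() vs. scipy's lfilter impulse response
# for a single 2nd-order A(z) (scipy assumed available).
import numpy as np
import torch
from scipy.signal import lfilter
from filters import filter_iir_response

a = torch.tensor([[[1.0, -0.9, 0.02]]])        # A(z) = 1 - 0.9 z^-1 + 0.02 z^-2
N = 8
R = filter_iir_response(a, N)                   # shape (1, 1, N)

impulse = np.zeros(N)
impulse[0] = 1.0
ref = lfilter([1.0], a[0, 0].numpy(), impulse)  # impulse response of 1/A(z)
print(np.allclose(R[0, 0].numpy(), ref, atol=1e-5))   # expected: True
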
dnn/torch/fargan/stft_loss.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""STFT-based Loss modules."""

import torch
import torch.nn.functional as F
import numpy as np
import torchaudio


def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """

    #x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
    #real = x_stft[..., 0]
    #imag = x_stft[..., 1]

    # (kan-bayashi): clamp is needed to avoid nan or inf
    #return torchaudio.functional.amplitude_to_DB(torch.abs(x_stft),db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
    #return torch.clamp(torch.abs(x_stft), min=1e-7)

    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
    return torch.clamp(torch.abs(x_stft), min=1e-7)

class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initilize spectral convergence loss module."""
        super(SpectralConvergenceLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Spectral convergence loss value.
        """
        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")

class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initilize los STFT magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        #F.l1_loss(torch.sqrt(y_mag), torch.sqrt(x_mag)) +
        #F.l1_loss(torchaudio.functional.amplitude_to_DB(y_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80),\
        #torchaudio.functional.amplitude_to_DB(x_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80))

        #y_mag[:,:y_mag.size(1)//2,:] = y_mag[:,:y_mag.size(1)//2,:] *0.0

        #return F.l1_loss(torch.log(y_mag) + torch.sqrt(y_mag), torch.log(x_mag) + torch.sqrt(x_mag))

        #return F.l1_loss(y_mag, x_mag)

        error_loss = F.l1_loss(y, x) #+ F.l1_loss(torch.sqrt(y), torch.sqrt(x))#F.l1_loss(torch.log(y), torch.log(x))#

        #x = torch.log(x)
        #y = torch.log(y)
        #x = x.permute(0,2,1).contiguous()
        #y = y.permute(0,2,1).contiguous()

        '''mean_x = torch.mean(x, dim=1, keepdim=True)
        mean_y = torch.mean(y, dim=1, keepdim=True)

        var_x = torch.var(x, dim=1, keepdim=True)
        var_y = torch.var(y, dim=1, keepdim=True)

        std_x = torch.std(x, dim=1, keepdim=True)
        std_y = torch.std(y, dim=1, keepdim=True)

        x_minus_mean = x - mean_x
        y_minus_mean = y - mean_y

        pearson_corr = torch.sum(x_minus_mean * y_minus_mean, dim=1, keepdim=True) / \
            (torch.sqrt(torch.sum(x_minus_mean ** 2, dim=1, keepdim=True) + 1e-7) * \
            torch.sqrt(torch.sum(y_minus_mean ** 2, dim=1, keepdim=True) + 1e-7))

        numerator = 2.0 * pearson_corr * std_x * std_y
        denominator = var_x + var_y + (mean_y - mean_x)**2

        ccc = numerator/denominator

        ccc_loss = F.l1_loss(1.0 - ccc, torch.zeros_like(ccc))'''

        return error_loss #+ ccc_loss#+ ccc_loss


class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module."""
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length).to(device)
        self.spectral_convergenge_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):

    def __init__(self,
                 device,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 hop_sizes=[512, 256, 128, 64, 32, 16],
                 win_lengths=[2048, 1024, 512, 256, 128, 64],
                 window="hann_window"):

        '''def __init__(self,
                 device,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 hop_sizes=[256, 128, 64, 32, 16, 8],
                 win_lengths=[1024, 512, 256, 128, 64, 32],
                 window="hann_window"):'''

        '''def __init__(self,
                 device,
                 fft_sizes=[2560, 1280, 640, 320, 160, 80],
                 hop_sizes=[640, 320, 160, 80, 40, 20],
                 win_lengths=[2560, 1280, 640, 320, 160, 80],
                 window="hann_window"):'''

        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            #mag_loss += mag_l
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss #mag_loss #+
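
A minimal usage sketch (not part of this commit): as committed, MultiResolutionSTFTLoss.forward() returns only the averaged spectral-convergence term, so a single scalar comes back.

# Usage sketch: multi-resolution STFT loss between two random 1 s signals.
import torch
from stft_loss import MultiResolutionSTFTLoss

device = torch.device('cpu')
loss_fn = MultiResolutionSTFTLoss(device).to(device)

x = torch.randn(4, 16000)     # predicted signal (B, T)
y = torch.randn(4, 16000)     # ground-truth signal (B, T)
loss = loss_fn(x, y)          # scalar tensor (spectral convergence only)
print(loss.item())
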
dnn/torch/fargan/test_fargan.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import os
import argparse
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import fargan
from dataset import FARGANDataset

nb_features = 36
nb_used_features = 20

parser = argparse.ArgumentParser()

parser.add_argument('model', type=str, help='CELPNet model')
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')

parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)

args = parser.parse_args()

if args.cuda_visible_devices != None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices


features_file = args.features
signal_file = args.output


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint = torch.load(args.model, map_location='cpu')

model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

model.load_state_dict(checkpoint['state_dict'], strict=False)

features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
lpc = features[:,4-1:-1,nb_used_features:]
features = features[:, :, :nb_used_features]
periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')

nb_frames = features.shape[1]
#nb_frames = 1000
gamma = checkpoint['model_kwargs']['gamma']

def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):

    out = np.zeros_like(frame)
    filt = np.flip(filt)

    inp = frame[:]

    for i in range(0, inp.shape[0]):

        s = inp[i] - np.dot(buffer*weighting_vector, filt)

        buffer[0] = s

        buffer = np.roll(buffer, -1)

        out[i] = s

    return out

def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):

    #inverse perceptual weighting= H_preemph / W(z/gamma)

    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    num_frames = pw_signal.shape[0] //160
    assert num_frames == filters.shape[0]
    for frame_idx in range(0, num_frames):

        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
        buffer[:] = out_sig_frame[-16:]
    return signal



if __name__ == '__main__':
    model.to(device)
    features = torch.tensor(features).to(device)
    #lpc = torch.tensor(lpc).to(device)
    periods = torch.tensor(periods).to(device)

    sig, _ = model(features, periods, nb_frames - 4)
    weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
    sig = sig.detach().numpy().flatten()
    sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)

    pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
    pcm.tofile(signal_file)
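
For clarity, a sketch (not part of this commit) of the .f32 feature layout that both dataset.py and this script assume: 36 float32 values per 10 ms frame, the first 20 being the acoustic features fed to the model (with the pitch period encoded at index 18 and decoded as round(50*x + 100)) and the last 16 being the LPC coefficients used for the inverse perceptual weighting. The file name below is hypothetical.

# Sketch of the assumed per-frame feature layout (hypothetical file name).
import numpy as np

nb_frames = 100
features = np.zeros((nb_frames, 36), dtype=np.float32)
features[:, 18] = (160 - 100) / 50.0     # pitch period of 160 samples, decoded as round(50*x + 100)
features[:, 20:36] = 0.0                 # 16 LPC coefficients (left at zero here)
features.tofile('dummy_features.f32')
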
dnn/torch/fargan/train_fargan.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import os
import argparse
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import fargan
from dataset import FARGANDataset
from stft_loss import *

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--has-gain', action='store_true', help="use gain-shape network")
model_group.add_argument('--passthrough-size', type=int, help="state passing through in addition to audio, default: 0", default=0)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()

if args.cuda_visible_devices != None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.9, 0.99]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal

# model parameters
cond_size = args.cond_size


checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'has_gain': args.has_gain, 'passthrough_size': args.passthrough_size, 'gamma': args.gamma}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

#model = fargan.FARGAN()
#model = nn.DataParallel(model)

if type(args.initial_checkpoint) != type(None):
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()


dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))

states = None

spect_loss = MultiResolutionSTFTLoss(device).to(device)

if __name__ == '__main__':
    model.to(device)

    for epoch in range(1, epochs + 1):

        running_specc = 0
        running_cont_loss = 0
        running_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                lpc = lpc.to(device)
                periods = periods.to(device)
                target = target.to(device)
                target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)

                #nb_pre = random.randrange(1, 6)
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                sig = torch.cat([pre, sig], -1)

                cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
                specc_loss = spect_loss(sig, target.detach())
                loss = .2*cont_loss + specc_loss

                loss.backward()
                optimizer.step()

                #model.clip_weights()

                scheduler.step()

                running_specc += specc_loss.detach().cpu().item()
                running_cont_loss += cont_loss.detach().cpu().item()

                running_loss += loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   specc=f"{running_specc/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
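
A short sketch (not part of this commit) of reloading a checkpoint written by this script, mirroring what test_fargan.py does; the checkpoint carries the model hyperparameters alongside the weights, and the path below is hypothetical.

# Reload sketch for a checkpoint saved by train_fargan.py (hypothetical path).
import torch
import fargan

checkpoint = torch.load('output/checkpoints/fargan_20.pth', map_location='cpu')
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
print(checkpoint['epoch'], checkpoint['loss'])
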