added more enhancement stuff

Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Jan Buethe 2023-09-12 14:50:24 +02:00
parent 7b8ba143f1
commit 2f290d32ed
24 changed files with 3511 additions and 108 deletions


@@ -0,0 +1,458 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import math as m
import random
import yaml
from tqdm import tqdm
try:
import git
has_git = True
except ImportError:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from scipy.io import wavfile
import numpy as np
import pesq
from data import SilkEnhancementSet
from models import model_dict
from utils.silk_features import load_inference_data
from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
args = parser.parse_args()
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
setup = yaml.load(f.read(), yaml.FullLoader)
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'
# check model
if 'name' not in setup['model']:
print('warning: did not find model name in setup, using default PitchPostFilter')
model_name = 'pitchpostfilter'
else:
model_name = setup['model']['name']
# prepare output folder
if os.path.exists(args.output):
print("warning: output folder exists")
reply = input('continue? (y/n): ')
while reply not in {'y', 'n'}:
reply = input('continue? (y/n): ')
if reply == 'n':
sys.exit(0)
else:
os.makedirs(args.output, exist_ok=True)
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup
if has_git:
working_dir = os.path.split(__file__)[0]
try:
repo = git.Repo(working_dir)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
is_dirty = repo.is_dirty()
if is_dirty:
print("warning: repo is dirty")
setup['repo']['hash'] = hash
setup['repo']['urls'] = urls
setup['repo']['dirty'] = is_dirty
except Exception:
has_git = False
# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
yaml.dump(setup, f)
ref = None
if args.testdata is not None:
testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
inference_test = True
inference_folder = os.path.join(args.output, 'inference_test')
os.makedirs(inference_folder, exist_ok=True)
try:
ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
except FileNotFoundError:
pass
else:
inference_test = False
# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
lr_gen = lr * setup['training']['gen_lr_reduction']
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')
# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)
# load validation dataset if given
if 'validation_dataset' in setup:
validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
run_validation = True
else:
run_validation = False
# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
*setup['discriminator']['args'], **setup['discriminator']['kwargs']
)
# set compute device
if args.device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
# optimizer for trainable model parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)
# disc optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
# learning rate scheduler
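# inverse-linear decay: lr(step) = lr_gen / (1 + lr_decay_factor * step)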
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
if args.initial_checkpoint is not None:
print(f"loading state dict from {args.initial_checkpoint}...")
chkpt = torch.load(args.initial_checkpoint, map_location=device)
model.load_state_dict(chkpt['state_dict'])
if 'disc_state_dict' in chkpt:
print(f"loading discriminator state dict from {args.initial_checkpoint}...")
disc.load_state_dict(chkpt['disc_state_dict'])
if 'optimizer_state_dict' in chkpt:
print(f"loading optimizer state dict from {args.initial_checkpoint}...")
optimizer.load_state_dict(chkpt['optimizer_state_dict'])
if 'disc_optimizer_state_dict' in chkpt:
print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])
if 'scheduler_state_dict' in chkpt:
print(f"loading scheduler state dict from {args.initial_checkpoint}...")
scheduler.load_state_dict(chkpt['scheduler_state_dict'])
# if 'torch_rng_state' in chkpt:
# print(f"setting torch RNG state from {args.initial_checkpoint}...")
# torch.set_rng_state(chkpt['torch_rng_state'])
if 'numpy_rng_state' in chkpt:
print(f"setting numpy RNG state from {args.initial_checkpoint}...")
np.random.set_state(chkpt['numpy_rng_state'])
if 'python_rng_state' in chkpt:
print(f"setting Python RNG state from {args.initial_checkpoint}...")
random.setstate(chkpt['python_rng_state'])
# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
return torch.mean(loss)
def td_l2_norm(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
return loss.mean()
def td_l1(y_true, y_pred, pow=0):
dims = list(range(1, len(y_true.shape)))
tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
return torch.mean(tmp)
def criterion(x, y):
return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
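# A minimal sanity check (hedged sketch, not part of the training flow): for
# identical (batch, samples) tensors every term above vanishes, so
# criterion(x, x) should come out close to zero and strictly below
# criterion(x, y) for unrelated noise y, e.g.:
#   x, y = torch.randn(2, 16000), torch.randn(2, 16000)
#   assert criterion(x, x) < criterion(x, y)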
# model checkpoint
checkpoint = {
'setup' : setup,
'state_dict' : model.state_dict(),
'loss' : -1
}
if not args.no_redirect:
print(f"re-directing output to {os.path.join(args.output, output_file)}")
sys.stdout = open(os.path.join(args.output, output_file), "w")
print("summary:")
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
if ref is not None:
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
print(f"initial MOS (PESQ): {initial_mos}")
best_loss = 1e9
log_interval = 10
m_r = 0
m_f = 0
s_r = 1
s_f = 1
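# move optimizer state tensors to the compute device; needed when resuming from
# a checkpoint that was saved on a different device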
def optimizer_to(optim, device):
for param in optim.state.values():
if isinstance(param, torch.Tensor):
param.data = param.data.to(device)
if param._grad is not None:
param._grad.data = param._grad.data.to(device)
elif isinstance(param, dict):
for subparam in param.values():
if isinstance(subparam, torch.Tensor):
subparam.data = subparam.data.to(device)
if subparam._grad is not None:
subparam._grad.data = subparam._grad.data.to(device)
optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)
retain_grads(model)
retain_grads(disc)
for ep in range(1, epochs + 1):
print(f"training epoch {ep}...")
model.to(device)
disc.to(device)
model.train()
disc.train()
running_disc_loss = 0
running_adv_loss = 0
running_feature_loss = 0
running_reg_loss = 0
running_disc_grad_norm = 0
running_model_grad_norm = 0
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# set gradients to zero
optimizer.zero_grad()
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target'].to(device)
disc_target = batch[adv_target].to(device)
# calculate model output
output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
# discriminator update
scores_gen = disc(output.detach())
scores_real = disc(disc_target.unsqueeze(1))
disc_loss = 0
for score in scores_gen:
disc_loss += (((score[-1]) ** 2)).mean()
m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()
for score in scores_real:
disc_loss += (((1 - score[-1]) ** 2)).mean()
m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()
disc_loss = 0.5 * disc_loss / len(scores_gen)
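# estimated probability that a generated sample outscores a real one, assuming
# Gaussian statistics for the running score means and standard deviations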
winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
disc.zero_grad()
disc_loss.backward()
running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()
optimizer_disc.step()
# generator update
scores_gen = disc(output)
# calculate loss
loss_reg = criterion(output.squeeze(1), target)
num_discs = len(scores_gen)
gen_loss = 0
for score in scores_gen:
gen_loss += (((1 - score[-1]) ** 2)).mean() / num_discs
loss_feat = 0
for k in range(num_discs):
num_layers = len(scores_gen[k]) - 1
f = 4 / num_discs / num_layers
for l in range(num_layers):
loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())
model.zero_grad()
(gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
optimizer.step()
running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
running_adv_loss += gen_loss.detach().cpu().item()
running_disc_loss += disc_loss.detach().cpu().item()
running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
wc=f"{100*winning_chance:5.2f}%")
# save checkpoint
checkpoint['state_dict'] = model.state_dict()
checkpoint['disc_state_dict'] = disc.state_dict()
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
checkpoint['scheduler_state_dict'] = scheduler.state_dict()
checkpoint['torch_rng_state'] = torch.get_rng_state()
checkpoint['numpy_rng_state'] = np.random.get_state()
checkpoint['python_rng_state'] = random.getstate()
checkpoint['adv_loss'] = running_adv_loss/(i + 1)
checkpoint['disc_loss'] = running_disc_loss/(i + 1)
checkpoint['feature_loss'] = running_feature_loss/(i + 1)
checkpoint['reg_loss'] = running_reg_loss/(i + 1)
if inference_test:
print("running inference test...")
out = model.process(testsignal, features, periods, numbits).cpu().numpy()
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
if ref is not None:
mos = pesq.pesq(16000, ref, out, mode='wb')
print(f"MOS (PESQ): {mos}")
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
print()
print('Done')


@@ -0,0 +1,451 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import math as m
import random
import yaml
from tqdm import tqdm
try:
import git
has_git = True
except ImportError:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from scipy.io import wavfile
import numpy as np
import pesq
from data import LPCNetVocodingDataset
from models import model_dict
from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
args = parser.parse_args()
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
setup = yaml.load(f.read(), yaml.FullLoader)
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'
# check model
if 'name' not in setup['model']:
print('warning: did not find model name in setup, using default PitchPostFilter')
model_name = 'pitchpostfilter'
else:
model_name = setup['model']['name']
# prepare output folder
if os.path.exists(args.output):
print("warning: output folder exists")
reply = input('continue? (y/n): ')
while reply not in {'y', 'n'}:
reply = input('continue? (y/n): ')
if reply == 'n':
sys.exit(0)
else:
os.makedirs(args.output, exist_ok=True)
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup
if has_git:
working_dir = os.path.split(__file__)[0]
try:
repo = git.Repo(working_dir)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
is_dirty = repo.is_dirty()
if is_dirty:
print("warning: repo is dirty")
setup['repo']['hash'] = hash
setup['repo']['urls'] = urls
setup['repo']['dirty'] = is_dirty
except Exception:
has_git = False
# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
yaml.dump(setup, f)
ref = None
# prepare inference test if wanted
inference_test = False
if args.test_features is not None:
test_features = load_lpcnet_features(args.test_features)
features = test_features['features']
periods = test_features['periods']
inference_folder = os.path.join(args.output, 'inference_test')
os.makedirs(inference_folder, exist_ok=True)
inference_test = True
# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
lr_gen = lr * setup['training']['gen_lr_reduction']
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')
# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)
# load validation dataset if given
if 'validation_dataset' in setup:
validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
run_validation = True
else:
run_validation = False
# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
*setup['discriminator']['args'], **setup['discriminator']['kwargs']
)
# set compute device
if args.device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
# optimizer for trainable model parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)
# disc optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
if args.initial_checkpoint is not None:
print(f"loading state dict from {args.initial_checkpoint}...")
chkpt = torch.load(args.initial_checkpoint, map_location=device)
model.load_state_dict(chkpt['state_dict'])
if 'disc_state_dict' in chkpt:
print(f"loading discriminator state dict from {args.initial_checkpoint}...")
disc.load_state_dict(chkpt['disc_state_dict'])
if 'optimizer_state_dict' in chkpt:
print(f"loading optimizer state dict from {args.initial_checkpoint}...")
optimizer.load_state_dict(chkpt['optimizer_state_dict'])
if 'disc_optimizer_state_dict' in chkpt:
print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])
if 'scheduler_state_dict' in chkpt:
print(f"loading scheduler state dict from {args.initial_checkpoint}...")
scheduler.load_state_dict(chkpt['scheduler_state_dict'])
# if 'torch_rng_state' in chkpt:
# print(f"setting torch RNG state from {args.initial_checkpoint}...")
# torch.set_rng_state(chkpt['torch_rng_state'])
if 'numpy_rng_state' in chkpt:
print(f"setting numpy RNG state from {args.initial_checkpoint}...")
np.random.set_state(chkpt['numpy_rng_state'])
if 'python_rng_state' in chkpt:
print(f"setting Python RNG state from {args.initial_checkpoint}...")
random.setstate(chkpt['python_rng_state'])
# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
return torch.mean(loss)
def td_l2_norm(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
return loss.mean()
def td_l1(y_true, y_pred, pow=0):
dims = list(range(1, len(y_true.shape)))
tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
return torch.mean(tmp)
def criterion(x, y):
return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
# model checkpoint
checkpoint = {
'setup' : setup,
'state_dict' : model.state_dict(),
'loss' : -1
}
if not args.no_redirect:
print(f"re-directing output to {os.path.join(args.output, output_file)}")
sys.stdout = open(os.path.join(args.output, output_file), "w")
print("summary:")
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
if ref is not None:
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
print(f"initial MOS (PESQ): {initial_mos}")
best_loss = 1e9
log_interval = 10
m_r = 0
m_f = 0
s_r = 1
s_f = 1
def optimizer_to(optim, device):
for param in optim.state.values():
if isinstance(param, torch.Tensor):
param.data = param.data.to(device)
if param._grad is not None:
param._grad.data = param._grad.data.to(device)
elif isinstance(param, dict):
for subparam in param.values():
if isinstance(subparam, torch.Tensor):
subparam.data = subparam.data.to(device)
if subparam._grad is not None:
subparam._grad.data = subparam._grad.data.to(device)
optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)
for ep in range(1, epochs + 1):
print(f"training epoch {ep}...")
model.to(device)
disc.to(device)
model.train()
disc.train()
running_disc_loss = 0
running_adv_loss = 0
running_feature_loss = 0
running_reg_loss = 0
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# set gradients to zero
optimizer.zero_grad()
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target'].to(device)
disc_target = batch[adv_target].to(device)
# calculate model output
output = model(batch['features'], batch['periods'])
# discriminator update
scores_gen = disc(output.detach())
scores_real = disc(disc_target.unsqueeze(1))
disc_loss = 0
for scale in scores_gen:
disc_loss += ((scale[-1]) ** 2).mean()
m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()
for scale in scores_real:
disc_loss += ((1 - scale[-1]) ** 2).mean()
m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()
disc_loss = 0.5 * disc_loss / len(scores_gen)
winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
disc.zero_grad()
disc_loss.backward()
optimizer_disc.step()
# generator update
scores_gen = disc(output)
# calculate loss
loss_reg = criterion(output.squeeze(1), target)
num_discs = len(scores_gen)
loss_gen = 0
for scale in scores_gen:
loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs
loss_feat = 0
for k in range(num_discs):
num_layers = len(scores_gen[k]) - 1
f = 4 / num_discs / num_layers
for l in range(num_layers):
loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())
model.zero_grad()
(loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
optimizer.step()
running_adv_loss += loss_gen.detach().cpu().item()
running_disc_loss += disc_loss.detach().cpu().item()
running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
wc=f"{100*winning_chance:5.2f}%")
# save checkpoint
checkpoint['state_dict'] = model.state_dict()
checkpoint['disc_state_dict'] = disc.state_dict()
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
checkpoint['scheduler_state_dict'] = scheduler.state_dict()
checkpoint['torch_rng_state'] = torch.get_rng_state()
checkpoint['numpy_rng_state'] = np.random.get_state()
checkpoint['python_rng_state'] = random.getstate()
checkpoint['adv_loss'] = running_adv_loss/(i + 1)
checkpoint['disc_loss'] = running_disc_loss/(i + 1)
checkpoint['feature_loss'] = running_feature_loss/(i + 1)
checkpoint['reg_loss'] = running_reg_loss/(i + 1)
if inference_test:
print("running inference test...")
out = model.process(features, periods).cpu().numpy()
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
if ref is not None:
mos = pesq.pesq(16000, ref, out, mode='wb')
print(f"MOS (PESQ): {mos}")
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
print()
print('Done')


@@ -1,30 +1,2 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from .silk_enhancement_set import SilkEnhancementSet
from .lpcnet_vocoding_dataset import LPCNetVocodingDataset


@@ -0,0 +1,225 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" Dataset for LPCNet training """
import os
import yaml
import torch
import numpy as np
from torch.utils.data import Dataset
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
def ulaw2lin(u):
u = u - 128
s = np.sign(u)
u = np.abs(u)
return s*scale_1*(np.exp(u/128.*np.log(256))-1)
def lin2ulaw(x):
s = np.sign(x)
x = np.abs(x)
u = (s*(128*np.log(1+scale*x)/np.log(256)))
u = np.clip(128 + np.round(u), 0, 255)
return u
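# Hedged sketch: lin2ulaw and ulaw2lin are mutual inverses up to 8-bit mu-law
# quantization, e.g.
#   x = np.array([-20000.0, -100.0, 0.0, 100.0, 20000.0])
#   ulaw2lin(lin2ulaw(x))  # recovers x up to the local quantization step size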
def run_lpc(signal, lpcs, frame_length=160):
num_frames, lpc_order = lpcs.shape
prediction = np.concatenate(
[- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
)
error = signal[lpc_order :] - prediction
return prediction, error
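# Reading of the convention above: lpcs holds one row of lpc_order coefficients
# per frame; each sample is predicted from the lpc_order preceding samples via
# the sign-flipped convolution, so `error` is the LPC residual and the first
# lpc_order samples of `signal` are consumed as prediction history.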
class LPCNetVocodingDataset(Dataset):
def __init__(self,
path_to_dataset,
features=['cepstrum', 'periods', 'pitch_corr'],
target='signal',
frames_per_sample=100,
feature_history=0,
feature_lookahead=0,
lpc_gamma=1):
super().__init__()
# load dataset info
self.path_to_dataset = path_to_dataset
with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
dataset = yaml.load(f, yaml.FullLoader)
# dataset version
self.version = dataset['version']
if self.version == 1:
self.getitem = self.getitem_v1
elif self.version == 2:
self.getitem = self.getitem_v2
else:
raise ValueError(f"dataset version {self.version} unknown")
# features
self.feature_history = feature_history
self.feature_lookahead = feature_lookahead
self.frame_offset = 2 + self.feature_history
self.frames_per_sample = frames_per_sample
self.input_features = features
self.feature_frame_layout = dataset['feature_frame_layout']
self.lpc_gamma = lpc_gamma
# load feature file
self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
self.feature_frame_length = dataset['feature_frame_length']
assert len(self.features) % self.feature_frame_length == 0
self.features = self.features.reshape((-1, self.feature_frame_length))
# derive number of samples in dataset
self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample
# signals
self.frame_length = dataset['frame_length']
self.signal_frame_layout = dataset['signal_frame_layout']
self.target = target
# load signals
self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
self.signal_frame_length = dataset['signal_frame_length']
self.signals = self.signals.reshape((-1, self.signal_frame_length))
assert len(self.signals) == len(self.features) * self.frame_length
def __getitem__(self, index):
return self.getitem(index)
def getitem_v2(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
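# map the normalized pitch feature to an integer period in samples
# (assumed encoding: period ~ 50 * p + 100; the 0.1 offset guards the int16 truncation)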
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]
# calculate prediction and error if lpc coefficients present and prediction not given
if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
# lpc coefficients with one frame lookahead
# frame positions (start one frame early for past excitation)
frame_start = self.frame_offset + self.frames_per_sample * index - 1
frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)
# feature positions
lpc_start, lpc_stop = self.feature_frame_layout['lpc']
lpc_order = lpc_stop - lpc_start
lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]
# LPC weighting
weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
lpcs = lpcs * weights
# signal position (lpc_order samples as history)
signal_start = frame_start * self.frame_length - lpc_order + 1
signal_stop = frame_stop * self.frame_length + 1
noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]
noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)
# extract signals
offset = self.frame_length
sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
# calculate error between real signal and noisy prediction
sample['error'] = sample['signal'] - sample['prediction']
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
target = torch.FloatTensor(sample[self.target]) / 2**15
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'target' : target}
def getitem_v1(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
for signal_name, index in self.signal_frame_layout.items():
sample[signal_name] = self.signals[signal_start : signal_stop, index]
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
target = torch.LongTensor(sample[self.target])
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
def __len__(self):
return self.dataset_length
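# Minimal usage sketch (hypothetical path; the directory is assumed to contain
# the info.yml, feature file and signal file described above):
#   dataset = LPCNetVocodingDataset('/path/to/dataset', frames_per_sample=100)
#   sample = dataset[0]
#   print(sample['features'].shape, sample['periods'].shape, sample['target'].shape)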


@@ -50,7 +50,7 @@ class SilkEnhancementSet(Dataset):
noisy_spec_scale='opus',
noisy_apply_dct=True,
add_offset=False,
add_double_lag_acorr=False,
):
assert frames_per_sample % 4 == 0
@@ -75,7 +75,8 @@ class SilkEnhancementSet(Dataset):
self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
self.clean_signal_hp = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
self.clean_signal = np.fromfile(os.path.join(path, 'clean.s16'), dtype=np.int16)
self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)
self.create_features = silk_feature_factory(no_pitch_value,
@@ -92,7 +93,7 @@ class SilkEnhancementSet(Dataset):
# discard some frames to have enough signal history
self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)
num_frames = self.clean_signal_hp.shape[0] // 80 - self.skip_frames
self.len = num_frames // frames_per_sample
@@ -107,6 +108,7 @@ class SilkEnhancementSet(Dataset):
signal_start = frame_start * self.frame_size - self.skip
signal_stop = frame_stop * self.frame_size - self.skip
clean_signal_hp = self.clean_signal_hp[signal_start : signal_stop].astype(np.float32) / 2**15
clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15
@@ -124,6 +126,7 @@ class SilkEnhancementSet(Dataset):
if self.preemph > 0:
clean_signal[1:] -= self.preemph * clean_signal[: -1]
clean_signal_hp[1:] -= self.preemph * clean_signal_hp[: -1]
coded_signal[1:] -= self.preemph * coded_signal[: -1]
num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
@@ -134,7 +137,8 @@ class SilkEnhancementSet(Dataset):
return {
'features' : features,
'periods' : periods.astype(np.int64),
'target_orig' : clean_signal.astype(np.float32),
'target' : clean_signal_hp.astype(np.float32),
'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
'numbits' : numbits.astype(np.float32)
}


@@ -0,0 +1,101 @@
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
model.to(device)
model.train()
running_loss = 0
previous_running_loss = 0
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# set gradients to zero
optimizer.zero_grad()
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target']
# calculate model output
output = model(batch['features'], batch['periods'])
# calculate loss
if isinstance(output, list):
loss = torch.zeros(1, device=device)
for y in output:
loss = loss + criterion(target, y.squeeze(1))
loss = loss / len(output)
else:
loss = criterion(target, output.squeeze(1))
# calculate gradients
loss.backward()
# update weights
optimizer.step()
# update learning rate
scheduler.step()
# update running loss
running_loss += float(loss.cpu())
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
previous_running_loss = running_loss
running_loss /= len(dataloader)
return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
model.to(device)
model.eval()
running_loss = 0
previous_running_loss = 0
with torch.no_grad():
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target']
# calculate model output
output = model(batch['features'], batch['periods'])
# calculate loss
loss = criterion(target, output.squeeze(1))
# update running loss
running_loss += float(loss.cpu())
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
previous_running_loss = running_loss
running_loss /= len(dataloader)
return running_loss
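# Hedged usage sketch (model, criterion, optimizer, scheduler and the two
# dataloaders are assumed to be set up as in the accompanying training scripts):
#   for epoch in range(1, epochs + 1):
#       train_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
#       if run_validation:
#           val_loss = evaluate(model, criterion, validation_dataloader, device)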


@@ -27,6 +27,36 @@
*/
"""
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import sys
import argparse
import yaml
@@ -36,12 +66,19 @@ from utils.templates import setup_dict
parser = argparse.ArgumentParser()
parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace')
parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
args = parser.parse_args()
key = args.model + "_adv" if args.adversarial else args.model
try:
setup = setup_dict[key]
except KeyError:
print("setup not found, adversarial training possibly not specified for model")
sys.exit(1)
# update dataset if given
if args.path2dataset is not None:

@@ -29,10 +29,12 @@
from .lace import LACE
from .no_lace import NoLACE
from .lavoce import LaVoce
from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
model_dict = {
'lace': LACE,
'nolace': NoLACE,
'lavoce': LaVoce,
'fdmresdisc': FDMResDisc,
}


@@ -0,0 +1,974 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import copy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm, spectral_norm
import torchaudio
from utils.spec import gen_filterbank
# auxiliary functions
def remove_all_weight_norms(module):
for m in module.modules():
if hasattr(m, 'weight_v'):
nn.utils.remove_weight_norm(m)
def create_smoothing_kernel(h, w, gamma=1.5):
ch = h / 2 - 0.5
cw = w / 2 - 0.5
sh = gamma * ch
sw = gamma * cw
vx = ((torch.arange(h) - ch) / sh) ** 2
vy = ((torch.arange(w) - cw) / sw) ** 2
vals = vx.view(-1, 1) + vy.view(1, -1)
kernel = torch.exp(- vals)
kernel = kernel / kernel.sum()
return kernel
def create_kernel(h, w, sh, sw):
# proto kernel gives disjoint partition of 1
proto_kernel = torch.ones((sh, sw))
# create smoothing kernel eta
h_eta, w_eta = h - sh + 1, w - sw + 1
assert h_eta > 0 and w_eta > 0
eta = create_smoothing_kernel(h_eta, w_eta).view(1, 1, h_eta, w_eta)
kernel0 = F.pad(proto_kernel, [w_eta - 1, w_eta - 1, h_eta - 1, h_eta - 1]).unsqueeze(0).unsqueeze(0)
kernel = F.conv2d(kernel0, eta)
return kernel
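# e.g. create_kernel(8, 8, 2, 2) yields a (1, 1, 8, 8) kernel: the 2x2 box
# partition smoothed by the Gaussian-like window above; used for spreading
# discriminator scores back over the spectrogram in relevance_map below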
# positional embeddings
class FrequencyPositionalEmbedding(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
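# x is assumed shaped (batch, channels, freq, time); append sin/cos of
# the normalized frequency index as two extra channels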
N = x.size(2)
args = torch.arange(0, N, dtype=x.dtype, device=x.device) * torch.pi * 2 / N
cos = torch.cos(args).reshape(1, 1, -1, 1)
sin = torch.sin(args).reshape(1, 1, -1, 1)
zeros = torch.zeros_like(x[:, 0:1, :, :])
y = torch.cat((x, zeros + sin, zeros + cos), dim=1)
return y
class PositionalEmbedding2D(nn.Module):
def __init__(self, d=5):
super().__init__()
self.d = d
def forward(self, x):
N = x.size(2)
M = x.size(3)
h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1)
w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1)
h_sin = torch.sin(coeffs * h_args)
h_cos = torch.cos(coeffs * h_args)
w_sin = torch.sin(coeffs * w_args)
w_cos = torch.cos(coeffs * w_args)
zeros = torch.zeros_like(x[:, 0:1, :, :])
y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1)
return y
# spectral discriminator base class
class SpecDiscriminatorBase(nn.Module):
RECEPTIVE_FIELD_MAX_WIDTH = 10000
def __init__(self,
layers,
resolution,
fs=16000,
freq_roi=[50, 7000],
noise_gain=1e-3,
fmap_start_index=0
):
super().__init__()
self.layers = nn.ModuleList(layers)
self.resolution = resolution
self.fs = fs
self.noise_gain = noise_gain
self.fmap_start_index = fmap_start_index
if fmap_start_index >= len(layers):
raise ValueError('fmap_start_index is larger than number of layers')
# filter bank for noise shaping
n_fft = resolution[0]
self.filterbank = nn.Parameter(
gen_filterbank(n_fft // 2, fs, keep_size=True),
requires_grad=False
)
# roi bins
f_step = fs / n_fft
self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft//2 + 1)
self.init_weights()
# determine receptive field size, offsets and strides
hw = 1000
while True:
x = torch.zeros((1, hw, hw))
with torch.no_grad():
y = self.run_layer_stack(x)[-1]
pos0 = [y.size(-2) // 2, y.size(-1) // 2]
pos1 = [t + 1 for t in pos0]
hs0, ws0 = self._receptive_field((hw, hw), pos0)
hs1, ws1 = self._receptive_field((hw, hw), pos1)
h0 = hs0[1] - hs0[0] + 1
h1 = hs1[1] - hs1[0] + 1
w0 = ws0[1] - ws0[0] + 1
w1 = ws1[1] - ws1[0] + 1
if h0 != h1 or w0 != w1:
hw = 2 * hw
else:
# strides
sh = hs1[0] - hs0[0]
sw = ws1[0] - ws0[0]
if sh == 0 or sw == 0:
# degenerate stride: grow the probe input and try again
hw = 2 * hw
continue
# offsets
oh = hs0[0] - sh * pos0[0]
ow = ws0[0] - sw * pos0[1]
# overlap factor
overlap = w0 / sw + h0 / sh
#print(f"{w0=} {h0=} {sw=} {sh=} {overlap=}")
self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}
break
if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
print("warning: exceeded max size while trying to determine receptive field")
# create transposed convolutional kernel
#self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)
def run_layer_stack(self, spec):
output = []
x = spec.unsqueeze(1)
for layer in self.layers:
x = layer(x)
output.append(x)
return output
def forward(self, x):
""" returns array with feature maps and final score at index -1 """
output = []
x = self.spectrogram(x)
output = self.run_layer_stack(x)
return output[self.fmap_start_index:]
def receptive_field(self, output_pos):
if self.receptive_field_params is not None:
s, o, h = self.receptive_field_params['height']
h_min = output_pos[0] * s + o + self.start_bin
h_max = h_min + h
h_min = max(h_min, self.start_bin)
h_max = min(h_max, self.stop_bin)
s, o, w = self.receptive_field_params['width']
w_min = output_pos[1] * s + o
w_max = w_min + w
return (h_min, h_max), (w_min, w_max)
else:
return None, None
def _receptive_field(self, input_dims, output_pos):
""" determines receptive field probabilistically via autograd (slow) """
x = torch.randn((1,) + input_dims, requires_grad=True)
# run input through layers
y = self.run_layer_stack(x)[-1]
b, c, h, w = y.shape
if output_pos[0] >= h or output_pos[1] >= w:
raise ValueError("position out of range")
mask = torch.zeros((b, c, h, w))
mask[0, 0, output_pos[0], output_pos[1]] = 1
(mask * y).sum().backward()
hs, ws = torch.nonzero(x.grad[0], as_tuple=True)
h_min, h_max = hs.min().item(), hs.max().item()
w_min, w_max = ws.min().item(), ws.max().item()
return [h_min, h_max], [w_min, w_max]
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
nn.init.orthogonal_(m.weight.data)
def spectrogram(self, x):
n_fft, hop_length, win_length = self.resolution
x = x.squeeze(1)
window = torch.hann_window(win_length).to(x.device)
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
window=window, return_complex=True) #[B, F, T]
x = torch.abs(x)
# noise floor following spectral envelope
smoothed_x = torch.matmul(self.filterbank, x)
noise = torch.randn_like(x) * smoothed_x * self.noise_gain
x = x + noise
# frequency ROI
x = x[:, self.start_bin : self.stop_bin + 1, ...]
return torchaudio.functional.amplitude_to_DB(x, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80) # torch.sqrt(x)
def grad_map(self, x):
self.zero_grad()
n_fft, hop_length, win_length = self.resolution
window = torch.hann_window(win_length).to(x.device)
y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
window=window, return_complex=True) #[B, F, T]
y = torch.abs(y)
specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)
specgram.requires_grad = True
specgram.retain_grad()
if specgram.grad is not None:
specgram.grad.zero_()
y = specgram[:, self.start_bin : self.stop_bin + 1, ...]
scores = self.run_layer_stack(y)[-1]
loss = torch.mean((1 - scores) ** 2)
loss.backward()
return specgram.data[0], torch.abs(specgram.grad)[0]
def relevance_map(self, x):
n_fft, hop_length, win_length = self.resolution
y = x.view(-1)
window = torch.hann_window(win_length).to(x.device)
y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
window=window, return_complex=True) #[B, F, T]
y = torch.abs(y)
specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)
scores = self.forward(x)[-1]
sh, _, h = self.receptive_field_params['height']
sw, _, w = self.receptive_field_params['width']
kernel = create_kernel(h, w, sh, sw).float().to(scores.device)
with torch.no_grad():
pad_w = (w + sw - 1) // sw
pad_h = (h + sh - 1) // sh
padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
# CAVE: padding should be derived from offsets
rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h//2, w//2))
rv = rv[..., pad_h * sh : - pad_h * sh, pad_w * sw : -pad_w * sw]
relevance = torch.zeros_like(specgram)
relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv
return specgram, relevance
def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
""" layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """
# ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations
def newconv2d(layer,g):
new_layer = nn.Conv2d(layer.in_channels,
layer.out_channels,
layer.kernel_size,
stride=layer.stride,
padding=layer.padding,
dilation=layer.dilation,
groups=layer.groups)
try: new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
except AttributeError: pass
try: new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
except AttributeError: pass
return new_layer
bounds = {
64: [-85.82449722290039, 2.1755014657974243],
128: [-84.49211349487305, 3.5078893899917607],
256: [-80.33127822875977, 7.6687201976776125],
512: [-73.79328079223633, 14.20672025680542],
1024: [-67.59239501953125, 20.40760498046875],
2048: [-62.31902580261231, 25.680974197387698],
}
nfft = self.resolution[0]
if low is None: low = bounds[nfft][0]
if high is None: high = bounds[nfft][1]
remove_all_weight_norms(self)
for p in self.parameters():
if p.grad is not None:
p.grad.zero_()
num_layers = len(self.layers)
X = self.spectrogram(x).detach()
# forward pass
A = [X.unsqueeze(1)] + [None] * len(self.layers)
for i in range(num_layers - 1):
A[i + 1] = self.layers[i](A[i])
# initial relevance is last layer without activation
r = A[-2]
last_layer_rs = [r]
layer = self.layers[-1]
for sublayer in list(layer)[:-1]:
r = sublayer(r)
last_layer_rs.append(r)
mask = torch.zeros_like(r)
mask.requires_grad_(False)
if verbose:
print(r.min(), r.max())
if label in {'both', 'fake'}:
mask[r < -threshold] = 1
if label in {'both', 'real'}:
mask[r > threshold] = 1
r = r * mask
# backward pass
R = [None] * num_layers + [r]
for l in range(1, num_layers)[::-1]:
A[l] = (A[l]).data.requires_grad_(True)
layer = nn.Sequential(*(list(self.layers[l])[:-1]))
z = layer(A[l]) + eps
s = (R[l+1] / z).data
(z*s).sum().backward()
c = A[l].grad
R[l] = (A[l] * c).data
# first layer
A[0] = (A[0].data).requires_grad_(True)
Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)
if len(list(self.layers)) > 2:
# unsafe way to check for embedding layer
embed = list(self.layers[0])[0]
conv = list(self.layers[0])[1]
layer = nn.Sequential(embed, conv)
layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))
else:
layer = list(self.layers[0])[0]
layerl = newconv2d(layer, lambda p: p.clamp(min=0))
layerh = newconv2d(layer, lambda p: p.clamp(max=0))
z = layer(A[0])
z -= layerl(Xl) + layerh(Xh)
s = (R[1] / z).data
(z * s).sum().backward()
c, cp, cm = A[0].grad, Xl.grad, Xh.grad
R[0] = (A[0] * c + Xl * cp + Xh * cm)
#R[0] = (A[0] * c).data
return X, R[0].mean(dim=1)
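# For reference, a minimal sketch of the generic eps-rule that the backward
# pass above applies per layer (cf. the linked LRP tutorial); 'layer' stands
# for any sublayer stack with its final activation stripped, as in lrp().
def _example_eps_rule(layer, a, r_next, eps=1e-9):
    a = a.detach().requires_grad_(True)
    z = layer(a) + eps          # forward through the layer
    s = (r_next / z).detach()   # element-wise relevance / activation ratio
    (z * s).sum().backward()    # gradient trick redistributes relevance
    return (a * a.grad).detach()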
def create_3x3_conv_plan(num_layers : int,
f_stretch : int,
f_down : int,
t_stretch : int,
t_down : int
):
""" creates a stride, dilation, padding plan for a 2d conv network
Args:
num_layers (int): number of layers
f_stretch (int): log_2 of stretching factor along frequency axis
f_down (int): log_2 of downsampling factor along frequency axis
t_stretch (int): log_2 of stretching factor along time axis
t_down (int): log_2 of downsampling factor along time axis
Returns:
list(list(tuple)): list containing entries [(stride_f, stride_t), (dilation_f, dilation_t), (padding_f, padding_t)]
"""
assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
assert f_stretch < num_layers and t_stretch < num_layers
def process_dimension(n_layers, stretch, down):
stack_layers = n_layers - 1
stride_layers = min(min(down, stretch), stack_layers)
dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)
final_stride = 2 ** (max(down - stride_layers, 0))
final_dilation = 1
if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
final_dilation = 2
strides, dilations, paddings = [], [], []
processed_layers = 0
current_dilation = 1
for _ in range(stride_layers):
# increase receptive field and downsample via stride = 2
strides.append(2)
dilations.append(1)
paddings.append(1)
processed_layers += 1
if processed_layers < stack_layers:
strides.append(1)
dilations.append(1)
paddings.append(1)
processed_layers += 1
for _ in range(dilation_layers):
# increase receptive field via dilation = 2
strides.append(1)
current_dilation *= 2
dilations.append(current_dilation)
paddings.append(current_dilation)
processed_layers += 1
while processed_layers < n_layers - 1:
# fill up with std layers
strides.append(1)
dilations.append(current_dilation)
paddings.append(current_dilation)
processed_layers += 1
# final layer
strides.append(final_stride)
current_dilation *= final_dilation
dilations.append(current_dilation)
paddings.append(current_dilation)
processed_layers += 1
assert processed_layers == n_layers
return strides, dilations, paddings
t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)
plan = []
for i in range(num_layers):
plan.append([
(f_strides[i], t_strides[i]),
(f_dilations[i], t_dilations[i]),
(f_paddings[i], t_paddings[i]),
])
return plan
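# Usage sketch (illustrative, not from the original code): a 6-layer plan that
# stretches the receptive field and downsamples by 2**2 along frequency while
# leaving the time axis untouched.
def _example_conv_plan():
    plan = create_3x3_conv_plan(6, f_stretch=2, f_down=2, t_stretch=0, t_down=0)
    for (stride_f, stride_t), (dil_f, dil_t), (pad_f, pad_t) in plan:
        print(stride_f, stride_t, dil_f, dil_t, pad_f, pad_t)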
class DiscriminatorExperimental(SpecDiscriminatorBase):
def __init__(self,
resolution,
fs=16000,
freq_roi=[50, 7400],
noise_gain=0,
num_channels=16,
max_channels=512,
num_layers=5,
use_spectral_norm=False):
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.num_channels = num_channels
self.num_channels_max = max_channels
self.num_layers = num_layers
layers = []
stride = (2, 1)
padding = (1, 1)
in_channels = 1 + 2
out_channels = self.num_channels
for _ in range(self.num_layers):
layers.append(
nn.Sequential(
FrequencyPositionalEmbedding(),
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
nn.ReLU(inplace=True)
)
)
in_channels = out_channels + 2
out_channels = min(2 * out_channels, self.num_channels_max)
layers.append(
nn.Sequential(
FrequencyPositionalEmbedding(),
norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
nn.Sigmoid()
)
)
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
# bias biases
bias_val = 0.1
with torch.no_grad():
for name, weight in self.named_parameters():
if 'bias' in name:
weight += bias_val
configs = {
'f_down': {
'stretch' : {
64 : (0, 0),
128: (1, 0),
256: (2, 0),
512: (3, 0),
1024: (4, 0),
2048: (5, 0)
},
'down' : {
64 : (0, 0),
128: (1, 0),
256: (2, 0),
512: (3, 0),
1024: (4, 0),
2048: (5, 0)
}
},
'ft_down': {
'stretch' : {
64 : (0, 4),
128: (1, 3),
256: (2, 2),
512: (3, 1),
1024: (4, 0),
2048: (5, 0)
},
'down' : {
64 : (0, 4),
128: (1, 3),
256: (2, 2),
512: (3, 1),
1024: (4, 0),
2048: (5, 0)
}
},
'dilated': {
'stretch' : {
64 : (0, 4),
128: (1, 3),
256: (2, 2),
512: (3, 1),
1024: (4, 0),
2048: (5, 0)
},
'down' : {
64 : (0, 0),
128: (0, 0),
256: (0, 0),
512: (0, 0),
1024: (0, 0),
2048: (0, 0)
}
},
'mixed': {
'stretch' : {
64 : (0, 4),
128: (1, 3),
256: (2, 2),
512: (3, 1),
1024: (4, 0),
2048: (5, 0)
},
'down' : {
64 : (0, 0),
128: (1, 0),
256: (2, 0),
512: (3, 0),
1024: (4, 0),
2048: (5, 0)
}
},
}
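# Note: the tuples above are (frequency, time) pairs of log2 factors keyed by
# n_fft; 'stretch' sets receptive-field growth and 'down' the downsampling that
# create_3x3_conv_plan realizes via strides and dilations. E.g. 'f_down' only
# scales along frequency, while 'dilated' grows the receptive field without
# any downsampling.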
class DiscriminatorMagFree(SpecDiscriminatorBase):
def __init__(self,
resolution,
fs=16000,
freq_roi=[50, 7400],
noise_gain=0,
num_channels=16,
max_channels=256,
num_layers=5,
use_spectral_norm=False,
design=None):
if design is None:
raise ValueError('error: design parameter required in DiscriminatorMagFree')
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
stretch = configs[design]['stretch'][resolution[0]]
down = configs[design]['down'][resolution[0]]
self.num_channels = num_channels
self.num_channels_max = max_channels
self.num_layers = num_layers
self.stretch = stretch
self.down = down
layers = []
plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
in_channels = 1 + 2
out_channels = self.num_channels
for i in range(self.num_layers):
layers.append(
nn.Sequential(
FrequencyPositionalEmbedding(),
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
nn.ReLU(inplace=True)
)
)
in_channels = out_channels + 2
# product over strides
channel_factor = plan[i][0][0] * plan[i][0][1]
out_channels = min(channel_factor * out_channels, self.num_channels_max)
layers.append(
nn.Sequential(
FrequencyPositionalEmbedding(),
norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
nn.Sigmoid()
)
)
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
# bias biases
bias_val = 0.1
with torch.no_grad():
for name, weight in self.named_parameters():
if 'bias' in name:
weight += bias_val
class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):
def __init__(self,
resolution,
fs=16000,
freq_roi=[50, 7400],
noise_gain=0,
num_channels=16,
max_channels=512,
num_layers=5,
use_spectral_norm=False):
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.num_channels = num_channels
self.num_channels_max = max_channels
self.num_layers = num_layers
layers = []
stride = (2, 1)
padding = (1, 1)
in_channels = 1 + 2
out_channels = self.num_channels
for _ in range(self.num_layers):
layers.append(
nn.Sequential(
FrequencyPositionalEmbedding(),
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
nn.LeakyReLU(0.2, inplace=True)
)
)
in_channels = out_channels + 2
out_channels = min(2 * out_channels, self.num_channels_max)
layers.append(
nn.Sequential(
FrequencyPositionalEmbedding(),
norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
)
)
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag2dPositional(SpecDiscriminatorBase):
def __init__(self,
resolution,
fs=16000,
freq_roi=[50, 7400],
noise_gain=0,
num_channels=16,
max_channels=512,
num_layers=5,
d=5,
use_spectral_norm=False):
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.resolution = resolution
self.num_channels = num_channels
self.num_channels_max = max_channels
self.num_layers = num_layers
self.d = d
embedding_dim = 4 * d
layers = []
stride = (2, 2)
padding = (1, 1)
in_channels = 1 + embedding_dim
out_channels = self.num_channels
for _ in range(self.num_layers):
layers.append(
nn.Sequential(
PositionalEmbedding2D(d),
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
nn.LeakyReLU(0.2, inplace=True)
)
)
in_channels = out_channels + embedding_dim
out_channels = min(2 * out_channels, self.num_channels_max)
layers.append(
nn.Sequential(
PositionalEmbedding2D(),
norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
)
)
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag(SpecDiscriminatorBase):
def __init__(self,
resolution,
fs=16000,
freq_roi=[50, 7400],
noise_gain=0,
num_channels=32,
num_layers=5,
use_spectral_norm=False):
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.num_channels = num_channels
self.num_layers = num_layers
layers = []
stride = (1, 1)
padding = (1, 1)
in_channels = 1
out_channels = self.num_channels
for _ in range(self.num_layers):
layers.append(
nn.Sequential(
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
nn.LeakyReLU(0.2, inplace=True)
)
)
in_channels = out_channels
layers.append(norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)))
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
discriminators = {
'mag': DiscriminatorMag,
'freqpos': DiscriminatorMagFreqPosition,
'2dpos': DiscriminatorMag2dPositional,
'experimental': DiscriminatorExperimental,
'free': DiscriminatorMagFree
}
class TFDMultiResolutionDiscriminator(torch.nn.Module):
def __init__(self,
fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
architecture='mag',
fs=16000,
freq_roi=[50, 7400],
noise_gain=0,
use_spectral_norm=False,
**kwargs):
super().__init__()
fft_sizes = [int(round(fft_size_16k * fs / 16000)) for fft_size_16k in fft_sizes_16k]
resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes]
Disc = discriminators[architecture]
discs = [Disc(resolutions[i], fs=fs, freq_roi=freq_roi, noise_gain=noise_gain, use_spectral_norm=use_spectral_norm, **kwargs) for i in range(len(resolutions))]
self.discriminators = nn.ModuleList(discs)
def forward(self, y):
outputs = []
for disc in self.discriminators:
outputs.append(disc(y))
return outputs
class FWGAN_disc_wrapper(nn.Module):
def __init__(self, disc):
super().__init__()
self.disc = disc
def forward(self, y, y_hat):
out_real = self.disc(y)
out_fake = self.disc(y_hat)
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for y_real, y_fake in zip(out_real, out_fake):
y_d_rs.append(y_real[-1])
y_d_gs.append(y_fake[-1])
fmap_rs.append(y_real[:-1])
fmap_gs.append(y_fake[:-1])
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
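# Minimal sketch (assumption, names illustrative): typical least-squares GAN
# losses on top of the wrapper outputs; fmap_rs/fmap_gs would feed a
# feature-matching loss in the same way.
def _example_lsgan_losses(wrapper, y, y_hat):
    y_d_rs, y_d_gs, _, _ = wrapper(y, y_hat.detach())
    loss_disc = sum(torch.mean((1 - r) ** 2) + torch.mean(g ** 2)
                    for r, g in zip(y_d_rs, y_d_gs)) / len(y_d_rs)
    out_gen = wrapper.disc(y_hat)  # second pass keeps generator gradients
    loss_gen = sum(torch.mean((1 - o[-1]) ** 2) for o in out_gen) / len(out_gen)
    return loss_disc, loss_gen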

@@ -0,0 +1,254 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.noise_shaper import NoiseShaper
from utils.complexity import _conv1d_flop_count
from utils.endoscopy import write_data
from models.nns_base import NNSBase
from models.lpcnet_feature_net import LPCNetFeatureNet
from .scale_embedding import ScaleEmbedding
class LaVoce(nn.Module):
""" Linear-Adaptive VOCodEr """
FEATURE_FRAME_SIZE=160
FRAME_SIZE=80
def __init__(self,
num_features=20,
pitch_embedding_dim=64,
cond_dim=256,
pitch_max=300,
kernel_size=15,
preemph=0.85,
comb_gain_limit_db=-6,
global_gain_limits_db=[-6, 6],
conv_gain_limits_db=[-6, 6],
norm_p=2,
avg_pool_k=4,
pulses=False):
super().__init__()
self.num_features = num_features
self.cond_dim = cond_dim
self.pitch_max = pitch_max
self.pitch_embedding_dim = pitch_embedding_dim
self.kernel_size = kernel_size
self.preemph = preemph
self.pulses = pulses
assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE
# pitch embedding
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
# feature net
self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)
# noise shaper
self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)
# comb filters
left_pad = self.kernel_size // 2
right_pad = self.kernel_size - 1 - left_pad
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
# spectral shaping
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
# non-linear transforms
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
# combinators
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
# feature transforms
self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
def create_phase_signals(self, periods, pulses=False):
batch_size = periods.size(0)
progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
progression = torch.repeat_interleave(progression, batch_size, 0)
phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
chunks = []
for sframe in range(periods.size(1)):
f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
if pulses:
alpha = torch.cos(f)
chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
chunk = torch.cat((pulse_a, pulse_b), dim = 1)
else:
chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)
phase0 = phase0 + self.FRAME_SIZE * f
chunks.append(chunk)
phase_signals = torch.cat(chunks, dim=-1)
return phase_signals
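# Shape sketch (assumption): for a batch of 4 frames with a constant pitch lag
# of 80 samples, the phase advances by exactly one cycle per lag, giving a
# (1, 2, 4 * FRAME_SIZE) sine/cosine pair (or pulse pair when pulses=True).
def _example_phase_signals(model):
    periods = torch.full((1, 4), 80, dtype=torch.long)
    return model.create_phase_signals(periods).shape  # torch.Size([1, 2, 320])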
def flop_count(self, rate=16000, verbose=False):
frame_rate = rate / self.FRAME_SIZE
# feature net
feature_net_flops = self.feature_net.flop_count(frame_rate)
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
+ _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
if verbose:
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
return feature_net_flops + comb_flops + af_flops + feature_flops
def feature_transform(self, f, layer):
f = f.permute(0, 2, 1)
f = F.pad(f, [1, 0])
f = torch.tanh(layer(f))
return f.permute(0, 2, 1)
def forward(self, features, periods, debug=False):
periods = periods.squeeze(-1)
pitch_embedding = self.pitch_embedding(periods)
full_features = torch.cat((features, pitch_embedding), dim=-1)
cf = self.feature_net(full_features)
# upsample periods
periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)
# pre-net
ref_phase = torch.tanh(self.create_phase_signals(periods))
x = self.af_prescale(ref_phase, cf)
noise = self.noise_shaper(cf)
y = self.af_mix(torch.cat((x, noise), dim=1), cf)
if debug:
ch0 = y[0,0,:].detach().cpu().numpy()
ch1 = y[0,1,:].detach().cpu().numpy()
ch0 = np.round(32767 * ch0 / np.max(np.abs(ch0))).astype(np.int16)
ch1 = np.round(32767 * ch1 / np.max(np.abs(ch1))).astype(np.int16)
write_data('prior_channel0', ch0, 16000)
write_data('prior_channel1', ch1, 16000)
# temporal shaping + innovating
y1 = y[:, 0:1, :]
y2 = self.tdshape1(y[:, 1:2, :], cf)
y = torch.cat((y1, y2), dim=1)
y = self.af2(y, cf, debug=debug)
cf = self.feature_transform(cf, self.post_af2)
y1 = y[:, 0:1, :]
y2 = self.tdshape2(y[:, 1:2, :], cf)
y = torch.cat((y1, y2), dim=1)
y = self.af3(y, cf, debug=debug)
cf = self.feature_transform(cf, self.post_af3)
# spectral shaping
y = self.cf1(y, cf, periods, debug=debug)
cf = self.feature_transform(cf, self.post_cf1)
y = self.cf2(y, cf, periods, debug=debug)
cf = self.feature_transform(cf, self.post_cf2)
y = self.af1(y, cf, debug=debug)
cf = self.feature_transform(cf, self.post_af1)
# final temporal env adjustment
y1 = y[:, 0:1, :]
y2 = self.tdshape3(y[:, 1:2, :], cf)
y = torch.cat((y1, y2), dim=1)
y = self.af4(y, cf, debug=debug)
return y
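# Shape note (sketch): 100 feature frames of 10 ms, i.e. features (1, 100, 20)
# and periods (1, 100), correspond to one second of speech and yield an output
# waveform tensor of shape (1, 1, 16000).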
def process(self, features, periods, debug=False):
self.eval()
device = next(iter(self.parameters())).device
with torch.no_grad():
# run model
f = features.unsqueeze(0).to(device)
p = periods.unsqueeze(0).to(device)
y = self.forward(f, p, debug=debug).squeeze()
# deemphasis
if self.preemph > 0:
for i in range(len(y) - 1):
y[i + 1] += self.preemph * y[i]
# clip to valid range
out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
return out

@@ -0,0 +1,91 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class LPCNetFeatureNet(nn.Module):
def __init__(self,
feature_dim=84,
num_channels=256,
upsamp_factor=2,
lookahead=True):
super().__init__()
self.feature_dim = feature_dim
self.num_channels = num_channels
self.upsamp_factor = upsamp_factor
self.lookahead = lookahead
self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
self.conv2 = nn.Conv1d(num_channels, num_channels, 3)
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)
def flop_count(self, rate=100):
count = 0
for conv in self.conv1, self.conv2, self.tconv:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
return count
def forward(self, features, state=None):
""" features shape: (batch_size, num_frames, feature_dim) """
batch_size = features.size(0)
if state is None:
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
features = features.permute(0, 2, 1)
if self.lookahead:
c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
else:
c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
c = torch.tanh(self.tconv(c))
c = c.permute(0, 2, 1)
c, _ = self.gru(c, state)
return c
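# Shape check (sketch): 10 input frames with upsamp_factor=2 yield 20
# conditioning vectors; the padding keeps the frame count unchanged through
# both convolutions (one frame of future context when lookahead=True).
def _example_feature_net():
    net = LPCNetFeatureNet(feature_dim=84, num_channels=256, upsamp_factor=2)
    return net(torch.randn(1, 10, 84)).shape  # torch.Size([1, 20, 256])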

@@ -1,3 +1,31 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn

@@ -0,0 +1,103 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import torch
from scipy.io import wavfile
from time import time
from models import model_dict
from utils.lpcnet_features import load_lpcnet_features
from utils import endoscopy
debug = False
if debug:
args = type('dummy', (object,),
{
'input' : 'testitems/all_0_orig.se',
'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
'output' : 'out.wav',
})()
else:
parser = argparse.ArgumentParser()
parser.add_argument('input', type=str, help='path to input features')
parser.add_argument('checkpoint', type=str, help='checkpoint file')
parser.add_argument('output', type=str, help='output file')
parser.add_argument('--debug', action='store_true', help='enables debug output')
args = parser.parse_args()
torch.set_num_threads(2)
input_folder = args.input
checkpoint_file = args.checkpoint
output_file = args.output
if not output_file.endswith('.wav'):
output_file += '.wav'
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# check model
if 'name' not in checkpoint['setup']['model']:
print('warning: did not find model name entry in setup, using pitchpostfilter per default')
model_name = 'pitchpostfilter'
else:
model_name = checkpoint['setup']['model']['name']
model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])
# generate model input
setup = checkpoint['setup']
testdata = load_lpcnet_features(input_folder)
features = testdata['features']
periods = testdata['periods']
if args.debug:
endoscopy.init()
start = time()
output = model.process(features, periods, debug=args.debug)
elapsed = time() - start
print(f"[timing] inference took {elapsed * 1000} ms")
wavfile.write(output_file, 16000, output.cpu().numpy())
if args.debug:
endoscopy.close()

@@ -0,0 +1,287 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import yaml
try:
import git
has_git = True
except:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
from scipy.io import wavfile
import pesq
from data import LPCNetVocodingDataset
from models import model_dict
from engine.vocoder_engine import train_one_epoch, evaluate
from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
args = parser.parse_args()
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
setup = yaml.load(f.read(), yaml.FullLoader)
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'
# check model
if 'name' not in setup['model']:
print('warning: did not find model entry in setup, using default PitchPostFilter')
model_name = 'pitchpostfilter'
else:
model_name = setup['model']['name']
# prepare output folder
if os.path.exists(args.output):
print("warning: output folder exists")
reply = input('continue? (y/n): ')
while reply not in {'y', 'n'}:
reply = input('continue? (y/n): ')
if reply == 'n':
os._exit(0)
else:
os.makedirs(args.output, exist_ok=True)
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup
if has_git:
working_dir = os.path.split(__file__)[0]
try:
repo = git.Repo(working_dir)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
is_dirty = repo.is_dirty()
if is_dirty:
print("warning: repo is dirty")
setup['repo']['hash'] = hash
setup['repo']['urls'] = urls
setup['repo']['dirty'] = is_dirty
except:
has_git = False
# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
yaml.dump(setup, f)
ref = None
# prepare inference test if wanted
inference_test = False
if args.test_features is not None:
test_features = load_lpcnet_features(args.test_features)
features = test_features['features']
periods = test_features['periods']
inference_folder = os.path.join(args.output, 'inference_test')
os.makedirs(inference_folder, exist_ok=True)
inference_test = True
# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)
# load validation dataset if given
if 'validation_dataset' in setup:
validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
run_validation = True
else:
run_validation = False
# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
if args.initial_checkpoint is not None:
print(f"loading state dict from {args.initial_checkpoint}...")
chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
model.load_state_dict(chkpt['state_dict'])
# set compute device
if args.device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
# push model to device
model.to(device)
# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
# optimizer acts on trainable parameters only
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)
# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
return torch.mean(loss)
def td_l2_norm(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
return loss.mean()
def td_l1(y_true, y_pred, pow=0):
dims = list(range(1, len(y_true.shape)))
tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
return torch.mean(tmp)
def criterion(x, y):
return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
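# Note: criterion normalizes by w_sum, so the weights configured in the setup
# only matter relative to each other; the spectral terms (w_sc, w_lm, w_wsc,
# w_slm, w_sxcorr) are applied inside MRSTFTLoss rather than here.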
# model checkpoint
checkpoint = {
'setup' : setup,
'state_dict' : model.state_dict(),
'loss' : -1
}
if not args.no_redirect:
print(f"re-directing output to {os.path.join(args.output, output_file)}")
sys.stdout = open(os.path.join(args.output, output_file), "w")
print("summary:")
print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
if ref is not None:
pass
best_loss = 1e9
for ep in range(1, epochs + 1):
print(f"training epoch {ep}...")
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
# save checkpoint
checkpoint['state_dict'] = model.state_dict()
checkpoint['loss'] = new_loss
if run_validation:
print("running validation...")
validation_loss = evaluate(model, criterion, validation_dataloader, device)
checkpoint['validation_loss'] = validation_loss
if validation_loss < best_loss:
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
best_loss = validation_loss
if inference_test:
print("running inference test...")
out = model.process(features, periods).cpu().numpy()
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
if ref is not None:
mos = pesq.pesq(16000, ref, out, mode='wb')
print(f"MOS (PESQ): {mos}")
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))
print()
print('Done')

@@ -1,31 +1,4 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
def _conv1d_flop_count(layer, rate):
return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0]) * layer.kernel_size[0]

@@ -1,32 +1,3 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" module for inspecting models during inference """ """ module for inspecting models during inference """
import os import os

@@ -0,0 +1,100 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class NoiseShaper(nn.Module):
def __init__(self,
feature_dim,
frame_size=160
):
"""
Parameters:
-----------
feature_dim : int
dimension of input features
frame_size : int
frame size
"""
super().__init__()
self.feature_dim = feature_dim
self.frame_size = frame_size
# feature transform
self.feature_alpha1 = nn.Conv1d(self.feature_dim, frame_size, 2)
self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
def flop_count(self, rate):
frame_rate = rate / self.frame_size
shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
return shape_flops
def forward(self, features):
""" creates temporally shaped noise
Parameters:
-----------
features : torch.tensor
frame-wise features of shape (batch_size, num_frames, feature_dim)
"""
batch_size = features.size(0)
num_frames = features.size(1)
frame_size = self.frame_size
num_samples = num_frames * frame_size
# feature path
f = F.pad(features.permute(0, 2, 1), [1, 0])
alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
alpha = alpha.permute(0, 2, 1)
# signal generation
y = torch.randn((batch_size, num_frames, frame_size), dtype=features.dtype, device=features.device)
y = alpha * y
return y.reshape(batch_size, 1, num_samples)

@@ -1,3 +1,32 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """ """ This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """
import torch import torch

@@ -11,7 +11,8 @@ class TDShaper(nn.Module):
feature_dim,
frame_size=160,
avg_pool_k=4,
innovate=False,
pool_after=False
):
"""
@@ -39,6 +40,7 @@ class TDShaper(nn.Module):
self.frame_size = frame_size
self.avg_pool_k = avg_pool_k
self.innovate = innovate
self.pool_after = pool_after
assert frame_size % avg_pool_k == 0
self.env_dim = frame_size // avg_pool_k + 1
@@ -71,6 +73,10 @@
def envelope_transform(self, x):
x = torch.abs(x)
if self.pool_after:
x = torch.log(x + .5**16)
x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
else:
x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
x = torch.log(x + .5**16)

@@ -0,0 +1,112 @@
import os
import torch
import numpy as np
def load_lpcnet_features(feature_file, version=2):
if version == 2:
layout = {
'cepstrum': [0,18],
'periods': [18, 19],
'pitch_corr': [19, 20],
'lpc': [20, 36]
}
frame_length = 36
elif version == 1:
layout = {
'cepstrum': [0,18],
'periods': [36, 37],
'pitch_corr': [37, 38],
'lpc': [39, 55],
}
frame_length = 55
else:
raise ValueError(f'unknown feature version: {version}')
raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
raw_features = raw_features.reshape((-1, frame_length))
features = torch.cat(
[
raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
],
dim=1
)
lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()
return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
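# Usage sketch ('features.f32' is a placeholder path): the returned features
# are 19-dimensional (18 cepstral coefficients plus the pitch correlation) and
# the periods are integer pitch lags, as expected by LaVoce's pitch embedding.
def _example_load_features(feature_file='features.f32'):
    data = load_lpcnet_features(feature_file, version=2)
    return data['features'].shape[-1], data['periods'].dtype  # 19, torch.int64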
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
ref_data = np.memmap(reference_data_path, dtype=np.int16)
signal = np.memmap(signal_path, dtype=np.int16)
signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)
assert len(signal) % 160 == 0
num_frames = len(signal) // 160
mem = np.zeros(1)
for fr in range(len(signal)//160):
signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]
new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
new_data[:] = 0
N = len(signal) - offset
new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
def parse_warpq_scores(output_file):
""" extracts warpq scores from output file """
with open(output_file, "r") as f:
lines = f.readlines()
scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]
return scores
def parse_stats_file(file):
with open(file, "r") as f:
lines = f.readlines()
mean = float(lines[0].split(":")[-1])
bt_mean = float(lines[1].split(":")[-1])
top_mean = float(lines[2].split(":")[-1])
return mean, bt_mean, top_mean
def collect_test_stats(test_folder):
""" collects statistics for all discovered metrics from test folder """
metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
results = dict()
content = os.listdir(test_folder)
stats_files = [file for file in content if file.startswith('stats_')]
for file in stats_files:
metric = file[len("stats_") : -len(".txt")]
if metric not in metrics:
print(f"warning: unknown metric {metric}")
mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))
results[metric] = [mean, bt_mean, top_mean]
return results

@@ -40,3 +40,26 @@ def count_parameters(model, verbose=False):
total += count
return total
def retain_grads(module):
for p in module.parameters():
if p.requires_grad:
p.retain_grad()
def get_grad_norm(module, p=2):
norm = 0
for param in module.parameters():
if param.requires_grad:
norm = norm + (torch.abs(param.grad) ** p).sum()
return norm ** (1/p)
def create_weights(s_real, s_gen, alpha):
weights = []
with torch.no_grad():
for sr, sg in zip(s_real, s_gen):
weight = torch.exp(alpha * (sr[-1] - sg[-1]))
weights.append(weight)
return weights
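# Intent (a reading of the code, not documented upstream): each weight grows
# exponentially when the discriminator scores the real signal higher than the
# generated one, emphasizing the resolutions where the generator lags most.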

@@ -27,7 +27,6 @@
*/
"""
import os
import numpy as np

@@ -30,6 +30,7 @@
import math as m
import numpy as np
import scipy
import torch
def erb(f):
return 24.7 * (4.37 * f + 1)
@@ -49,6 +50,20 @@ scale_dict = {
'erb': [erb, inv_erb]
}
def gen_filterbank(N, Fs=16000, keep_size=False):
in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
M = N + 1 if keep_size else N
out_freq = (np.arange(M, dtype='float32')/N*Fs/2)[:,None]
#ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
ERB_N = 24.7 + .108*in_freq
delta = np.abs(in_freq-out_freq)/ERB_N
center = (delta<.5).astype('float32')
R = -12*center*delta**2 + (1-center)*(3-12*delta)
RE = 10.**(R/10.)
norm = np.sum(RE, axis=1)
RE = RE/norm[:, np.newaxis]
return torch.from_numpy(RE)
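# Usage sketch (assumption): smooth a magnitude STFT with the ERB filterbank,
# mirroring the noise-floor computation in the spectral discriminators.
def _example_smooth_spectrum(mag):
    # mag: (batch, n_fft // 2 + 1, frames) magnitude spectrogram
    fb = gen_filterbank(mag.size(1) - 1, Fs=16000)
    return torch.matmul(fb, mag)  # (batch, n_fft // 2, frames)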
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
f0 = 0

@@ -140,8 +140,196 @@ nolace_setup = {
}
}
nolace_setup_adv = {
'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
'model': {
'name': 'nolace',
'args': [],
'kwargs': {
'avg_pool_k': 4,
'comb_gain_limit_db': 10,
'cond_dim': 256,
'conv_gain_limits_db': [-12, 12],
'global_gain_limits_db': [-6, 6],
'hidden_feature_dim': 96,
'kernel_size': 15,
'num_features': 93,
'numbits_embedding_dim': 8,
'numbits_range': [50, 650],
'partial_lookahead': True,
'pitch_embedding_dim': 64,
'pitch_max': 300,
'preemph': 0.85,
'skip': 91
}
},
'data': {
'frames_per_sample': 100,
'no_pitch_value': 7,
'preemph': 0.85,
'skip': 91,
'pitch_hangover': 8,
'acorr_radius': 2,
'num_bands_clean_spec': 64,
'num_bands_noisy_spec': 18,
'noisy_spec_scale': 'opus',
},
'discriminator': {
'args': [],
'kwargs': {
'architecture': 'free',
'design': 'f_down',
'fft_sizes_16k': [
64,
128,
256,
512,
1024,
2048,
],
'freq_roi': [0, 7400],
'fs': 16000,
'max_channels': 256,
'noise_gain': 0.0,
},
'name': 'fdmresdisc',
},
'training': {
'adv_target': 'target_orig',
'batch_size': 64,
'epochs': 50,
'gen_lr_reduction': 1,
'lambda_feat': 1.0,
'lambda_reg': 0.6,
'loss': {
'w_l1': 0,
'w_l2': 10,
'w_lm': 0,
'w_logmel': 0,
'w_sc': 0,
'w_slm': 20,
'w_sxcorr': 1,
'w_wsc': 0,
'w_xcorr': 0,
},
'lr': 0.0001,
'lr_decay_factor': 2.5e-09,
}
}
lavoce_setup = {
'data': {
'frames_per_sample': 100,
'target': 'signal'
},
'dataset': '/local/datasets/lpcnet_large/training',
'model': {
'args': [],
'kwargs': {
'comb_gain_limit_db': 10,
'cond_dim': 256,
'conv_gain_limits_db': [-12, 12],
'global_gain_limits_db': [-6, 6],
'kernel_size': 15,
'num_features': 19,
'pitch_embedding_dim': 64,
'pitch_max': 300,
'preemph': 0.85,
'pulses': True
},
'name': 'lavoce'
},
'training': {
'batch_size': 256,
'epochs': 50,
'loss': {
'w_l1': 0,
'w_l2': 0,
'w_lm': 0,
'w_logmel': 0,
'w_sc': 0,
'w_slm': 2,
'w_sxcorr': 1,
'w_wsc': 0,
'w_xcorr': 0
},
'lr': 0.0005,
'lr_decay_factor': 2.5e-05
},
'validation_dataset': '/local/datasets/lpcnet_large/validation'
}
lavoce_setup_adv = {
'data': {
'frames_per_sample': 100,
'target': 'signal'
},
'dataset': '/local/datasets/lpcnet_large/training',
'discriminator': {
'args': [],
'kwargs': {
'architecture': 'free',
'design': 'f_down',
'fft_sizes_16k': [
64,
128,
256,
512,
1024,
2048,
],
'freq_roi': [0, 7400],
'fs': 16000,
'max_channels': 256,
'noise_gain': 0.0,
},
'name': 'fdmresdisc',
},
'model': {
'args': [],
'kwargs': {
'comb_gain_limit_db': 10,
'cond_dim': 256,
'conv_gain_limits_db': [-12, 12],
'global_gain_limits_db': [-6, 6],
'kernel_size': 15,
'num_features': 19,
'pitch_embedding_dim': 64,
'pitch_max': 300,
'preemph': 0.85,
'pulses': True
},
'name': 'lavoce'
},
'training': {
'batch_size': 64,
'epochs': 50,
'gen_lr_reduction': 1,
'lambda_feat': 1.0,
'lambda_reg': 0.6,
'loss': {
'w_l1': 0,
'w_l2': 0,
'w_lm': 0,
'w_logmel': 0,
'w_sc': 0,
'w_slm': 2,
'w_sxcorr': 1,
'w_wsc': 0,
'w_xcorr': 0
},
'lr': 0.0001,
'lr_decay_factor': 2.5e-09
},
}
setup_dict = {
'lace': lace_setup,
'nolace': nolace_setup,
'nolace_adv': nolace_setup_adv,
'lavoce': lavoce_setup,
'lavoce_adv': lavoce_setup_adv
}