mirror of https://github.com/xiph/opus.git · synced 2025-05-23 19:59:12 +00:00

added more enhancement stuff

Signed-off-by: Jan Buethe <jbuethe@amazon.de>
parent 7b8ba143f1 · commit 2f290d32ed
24 changed files with 3511 additions and 108 deletions
dnn/torch/osce/adv_train_model.py (new file, 458 lines)
@@ -0,0 +1,458 @@
"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import argparse
import sys
import math as m
import random

import yaml

from tqdm import tqdm

try:
    import git
    has_git = True
except ImportError:
    has_git = False

import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from scipy.io import wavfile
import numpy as np
import pesq

from data import SilkEnhancementSet
from models import model_dict


from utils.silk_features import load_inference_data
from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)


ref = None
if args.testdata is not None:

    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])

    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)

    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except Exception:
        pass
else:
    inference_test = False

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
lr_gen = lr * setup['training']['gen_lr_reduction']
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')

# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer only receives trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# discriminator optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])

# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

def xcorr_loss(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
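# note on criterion(): the combined regression loss is
#
#   loss = (w_l1 * L1 + L_MRSTFT + w_logmel * L_logmel
#           + w_xcorr * L_xcorr + w_l2 * L2) / w_sum
#
# where the MRSTFT term already applies w_sc, w_lm, w_wsc, w_slm and w_sxcorr
# internally, which is why w_sum normalizes over all nine loss weights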


# model checkpoint
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")


print("summary:")

print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9
log_interval = 10


# running estimates of the mean and std of real and fake discriminator scores
m_r = 0
m_f = 0
s_r = 1
s_f = 1

def optimizer_to(optim, device):
    # move optimizer state tensors to the compute device
    for param in optim.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)

retain_grads(model)
retain_grads(disc)

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0
    running_disc_grad_norm = 0
    running_model_grad_norm = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # discriminator update
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for score in scores_gen:
                disc_loss += (score[-1] ** 2).mean()
                m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()

            for score in scores_real:
                disc_loss += ((1 - score[-1]) ** 2).mean()
                m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
            winning_chance = 0.5 * m.erfc((m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)))
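            # under a Gaussian model of the two score distributions, the
            # expression above equals P(fake score > real score), i.e. an
            # estimate of the generator's chance of fooling the discriminator;
            # m_r/m_f and s_r/s_f are running mean/std estimates of the scores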

            disc.zero_grad()
            disc_loss.backward()

            running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()

            optimizer_disc.step()

            # generator update
            scores_gen = disc(output)

            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            gen_loss = 0
            for score in scores_gen:
                gen_loss += ((1 - score[-1]) ** 2).mean() / num_discs

            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
            running_adv_loss += gen_loss.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
                                   disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")


    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss / (i + 1)
    checkpoint['disc_loss'] = running_disc_loss / (i + 1)
    checkpoint['feature_loss'] = running_feature_loss / (i + 1)
    checkpoint['reg_loss'] = running_reg_loss / (i + 1)


    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))


    print()

print('Done')
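For orientation, here is a hypothetical minimal setup file for the script above (not part of this commit). The key names are taken from what the script reads; all values are illustrative placeholders, and the model and discriminator names must be keys of models.model_dict:

# sketch_setup.py -- writes an illustrative setup.yml for adv_train_model.py
import yaml

setup = {
    'model': {'name': 'pitchpostfilter', 'args': [], 'kwargs': {}},
    'discriminator': {'name': 'some_disc', 'args': [], 'kwargs': {}},  # placeholder name
    'dataset': '/path/to/training/dataset',
    'validation_dataset': '/path/to/validation/dataset',  # optional
    'data': {},  # kwargs forwarded to SilkEnhancementSet / load_inference_data
    'training': {
        'batch_size': 256, 'epochs': 50, 'lr': 5.0e-4, 'lr_decay_factor': 2.5e-5,
        'gen_lr_reduction': 1.0, 'lambda_feat': 1.0, 'lambda_reg': 1.0,
        'loss': {'w_l1': 0, 'w_lm': 0, 'w_slm': 2, 'w_sc': 0, 'w_logmel': 0,
                 'w_wsc': 0, 'w_xcorr': 0, 'w_sxcorr': 1, 'w_l2': 10}
    }
}

with open('setup.yml', 'w') as f:
    yaml.dump(setup, f)

# usage: python adv_train_model.py setup.yml output_dir [--testdata DIR]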
dnn/torch/osce/adv_train_vocoder.py (new file, 451 lines)
@@ -0,0 +1,451 @@
"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import argparse
import sys
import math as m
import random

import yaml

from tqdm import tqdm

try:
    import git
    has_git = True
except ImportError:
    has_git = False

import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from scipy.io import wavfile
import numpy as np
import pesq

from data import LPCNetVocodingDataset
from models import model_dict


from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)


ref = None
# prepare inference test if wanted
inference_test = False
if args.test_features is not None:
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True


# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
lr_gen = lr * setup['training']['gen_lr_reduction']
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')


# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])


# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)


# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)


# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer only receives trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# discriminator optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])

# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

def xcorr_loss(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")


print("summary:")

print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

# ref is never assigned in this script, so this block is inactive
if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9
log_interval = 10


# running estimates of the mean and std of real and fake discriminator scores
m_r = 0
m_f = 0
s_r = 1
s_f = 1

def optimizer_to(optim, device):
    # move optimizer state tensors to the compute device
    for param in optim.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)


for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # discriminator update
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for scale in scores_gen:
                disc_loss += (scale[-1] ** 2).mean()
                m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()

            for scale in scores_real:
                disc_loss += ((1 - scale[-1]) ** 2).mean()
                m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
            # estimated chance of fooling the discriminator (cf. adv_train_model.py)
            winning_chance = 0.5 * m.erfc((m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)))

            disc.zero_grad()
            disc_loss.backward()
            optimizer_disc.step()

            # generator update
            scores_gen = disc(output)


            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            loss_gen = 0
            for scale in scores_gen:
                loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs

            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            running_adv_loss += loss_gen.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")


    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss / (i + 1)
    checkpoint['disc_loss'] = running_disc_loss / (i + 1)
    checkpoint['feature_loss'] = running_feature_loss / (i + 1)
    checkpoint['reg_loss'] = running_reg_loss / (i + 1)


    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))


    print()

print('Done')
dnn/torch/osce/data/__init__.py
@@ -1,30 +1,2 @@
-"""
-/* Copyright (c) 2023 Amazon
-   Written by Jan Buethe */
-/*
-   Redistribution and use in source and binary forms, with or without
-   modification, are permitted provided that the following conditions
-   are met:
-
-   - Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.
-
-   - Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in the
-   documentation and/or other materials provided with the distribution.
-
-   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
-   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
-   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
-   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-*/
-"""
-
-from .silk_enhancement_set import SilkEnhancementSet
+from .silk_enhancement_set import SilkEnhancementSet
+from .lpcnet_vocoding_dataset import LPCNetVocodingDataset
dnn/torch/osce/data/lpcnet_vocoding_dataset.py (new file, 225 lines)
@@ -0,0 +1,225 @@
"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

""" Dataset for LPCNet training """
import os

import yaml
import torch
import numpy as np
from torch.utils.data import Dataset


scale = 255.0 / 32768.0
scale_1 = 32768.0 / 255.0

def ulaw2lin(u):
    u = u - 128
    s = np.sign(u)
    u = np.abs(u)
    return s * scale_1 * (np.exp(u / 128. * np.log(256)) - 1)


def lin2ulaw(x):
    s = np.sign(x)
    x = np.abs(x)
    u = s * (128 * np.log(1 + scale * x) / np.log(256))
    u = np.clip(128 + np.round(u), 0, 255)
    return u
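# illustrative round trip (not part of the commit): for 16-bit samples x,
# ulaw2lin(lin2ulaw(x)) recovers x up to mu-law quantization error, which
# grows to a few percent of |x| at large amplitudes because each of the
# 256 codes covers an exponentially growing input range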

def run_lpc(signal, lpcs, frame_length=160):
    num_frames, lpc_order = lpcs.shape

    prediction = np.concatenate(
        [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
    )
    error = signal[lpc_order :] - prediction

    return prediction, error
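# shape note: run_lpc expects len(signal) == num_frames * frame_length + lpc_order;
# prediction and error then both have length num_frames * frame_length and are
# aligned with signal[lpc_order:]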

class LPCNetVocodingDataset(Dataset):
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1):

        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        self.frame_offset = 2 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.target = target

        # load signals
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length


    def __getitem__(self, index):
        return self.getitem(index)

    def getitem_v2(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # lpc coefficients with one frame lookahead
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']


        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        target = torch.FloatTensor(sample[self.target]) / 2**15
        periods = torch.LongTensor(sample['periods'])

        return {'features': features, 'periods': periods, 'target': target}

    def getitem_v1(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        for signal_name, signal_index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, signal_index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features': features, 'periods': periods, 'signals': signals, 'target': target}

    def __len__(self):
        return self.dataset_length
dnn/torch/osce/data/silk_enhancement_set.py
@@ -50,7 +50,7 @@ class SilkEnhancementSet(Dataset):
                  noisy_spec_scale='opus',
                  noisy_apply_dct=True,
                  add_offset=False,
-                 add_double_lag_acorr=False
+                 add_double_lag_acorr=False,
                  ):

         assert frames_per_sample % 4 == 0
@@ -75,8 +75,9 @@ class SilkEnhancementSet(Dataset):
         self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
         self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)

-        self.clean_signal = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
+        self.clean_signal_hp = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
+        self.clean_signal = np.fromfile(os.path.join(path, 'clean.s16'), dtype=np.int16)
         self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)

         self.create_features = silk_feature_factory(no_pitch_value,
                                                     acorr_radius,
@@ -92,7 +93,7 @@ class SilkEnhancementSet(Dataset):
         # discard some frames to have enough signal history
         self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)

-        num_frames = self.clean_signal.shape[0] // 80 - self.skip_frames
+        num_frames = self.clean_signal_hp.shape[0] // 80 - self.skip_frames

         self.len = num_frames // frames_per_sample

@@ -107,8 +108,9 @@ class SilkEnhancementSet(Dataset):
         signal_start = frame_start * self.frame_size - self.skip
         signal_stop = frame_stop * self.frame_size - self.skip

-        clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
+        clean_signal_hp = self.clean_signal_hp[signal_start : signal_stop].astype(np.float32) / 2**15
+        clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
         coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

         coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

@@ -124,6 +126,7 @@ class SilkEnhancementSet(Dataset):

         if self.preemph > 0:
             clean_signal[1:] -= self.preemph * clean_signal[: -1]
+            clean_signal_hp[1:] -= self.preemph * clean_signal_hp[: -1]
             coded_signal[1:] -= self.preemph * coded_signal[: -1]

         num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
@@ -132,9 +135,10 @@ class SilkEnhancementSet(Dataset):
         numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

         return {
             'features' : features,
             'periods' : periods.astype(np.int64),
-            'target' : clean_signal.astype(np.float32),
+            'target_orig' : clean_signal.astype(np.float32),
+            'target' : clean_signal_hp.astype(np.float32),
             'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
             'numbits' : numbits.astype(np.float32)
         }
dnn/torch/osce/engine/vocoder_engine.py (new file, 101 lines)
@ -0,0 +1,101 @@
|
import torch
from tqdm import tqdm
import sys

def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # calculate loss
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss

def evaluate(model, criterion, dataloader, device, log_interval=10):

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'])

                # calculate loss
                loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss
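For orientation, a minimal wiring sketch for the engine above (train_loader, val_loader, the optimizer settings and the loss are assumptions, not part of this commit):

model = model_dict['lavoce']()                                    # hypothetical instantiation
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
criterion = torch.nn.functional.mse_loss                          # placeholder loss
train_loss = train_one_epoch(model, criterion, optimizer, train_loader, 'cpu', scheduler)
val_loss = evaluate(model, criterion, val_loader, 'cpu')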
@@ -27,6 +27,36 @@
*/
"""

"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import sys
import argparse

import yaml
@@ -36,12 +66,19 @@ from utils.templates import setup_dict

parser = argparse.ArgumentParser()

parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace')
parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)

args = parser.parse_args()

key = args.model + "_adv" if args.adversarial else args.model

try:
    setup = setup_dict[key]
except KeyError:
    print("setup not found; adversarial training may not be specified for this model")
    sys.exit(1)

# update dataset if given
if args.path2dataset is not None:
@@ -29,10 +29,12 @@

from .lace import LACE
from .no_lace import NoLACE
from .lavoce import LaVoce
from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc

model_dict = {
    'lace': LACE,
    'nolace': NoLACE,
    'lavoce': LaVoce,
    'fdmresdisc': FDMResDisc,
}
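Hedged usage note: the training and test scripts resolve models through this registry, so the new entries can be built by name; the keyword arguments below are placeholders, not values from this commit.

from models import model_dict

model = model_dict['lavoce'](num_features=20, cond_dim=256)
disc = model_dict['fdmresdisc'](architecture='free', design='mixed')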
974
dnn/torch/osce/models/fd_discriminator.py
Normal file
@@ -0,0 +1,974 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import copy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm, spectral_norm
import torchaudio

from utils.spec import gen_filterbank

# auxiliary functions

def remove_all_weight_norms(module):
    for m in module.modules():
        if hasattr(m, 'weight_v'):
            nn.utils.remove_weight_norm(m)


def create_smoothing_kernel(h, w, gamma=1.5):

    ch = h / 2 - 0.5
    cw = w / 2 - 0.5

    sh = gamma * ch
    sw = gamma * cw

    vx = ((torch.arange(h) - ch) / sh) ** 2
    vy = ((torch.arange(w) - cw) / sw) ** 2
    vals = vx.view(-1, 1) + vy.view(1, -1)
    kernel = torch.exp(- vals)
    kernel = kernel / kernel.sum()

    return kernel


def create_kernel(h, w, sh, sw):
    # proto kernel gives disjoint partition of 1
    proto_kernel = torch.ones((sh, sw))

    # create smoothing kernel eta
    h_eta, w_eta = h - sh + 1, w - sw + 1
    assert h_eta > 0 and w_eta > 0
    eta = create_smoothing_kernel(h_eta, w_eta).view(1, 1, h_eta, w_eta)

    kernel0 = F.pad(proto_kernel, [w_eta - 1, w_eta - 1, h_eta - 1, h_eta - 1]).unsqueeze(0).unsqueeze(0)
    kernel = F.conv2d(kernel0, eta)

    return kernel
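A small sanity check of the two kernel builders (sizes chosen arbitrarily, not from the source): the smoothing kernel is normalized to sum to one, and create_kernel keeps a single input/output channel.

k = create_smoothing_kernel(9, 9)
print(k.shape, float(k.sum()))      # torch.Size([9, 9]) 1.0
tk = create_kernel(h=15, w=15, sh=4, sw=4)
print(tk.shape)                     # torch.Size([1, 1, 15, 15])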
# positional embeddings
class FrequencyPositionalEmbedding(nn.Module):
    def __init__(self):

        super().__init__()

    def forward(self, x):

        N = x.size(2)
        args = torch.arange(0, N, dtype=x.dtype, device=x.device) * torch.pi * 2 / N
        cos = torch.cos(args).reshape(1, 1, -1, 1)
        sin = torch.sin(args).reshape(1, 1, -1, 1)
        zeros = torch.zeros_like(x[:, 0:1, :, :])

        y = torch.cat((x, zeros + sin, zeros + cos), dim=1)

        return y


class PositionalEmbedding2D(nn.Module):
    def __init__(self, d=5):

        super().__init__()

        self.d = d

    def forward(self, x):

        N = x.size(2)
        M = x.size(3)

        h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1)
        w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
        coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1)

        h_sin = torch.sin(coeffs * h_args)
        h_cos = torch.cos(coeffs * h_args)
        w_sin = torch.sin(coeffs * w_args)
        w_cos = torch.cos(coeffs * w_args)

        zeros = torch.zeros_like(x[:, 0:1, :, :])

        y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1)

        return y
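Shape check for the embeddings (input sizes are illustrative): both modules append positional channels along dim 1, which is why the conv stacks below add 2 or 4*d to their input channel counts.

x = torch.zeros(1, 1, 64, 100)
print(FrequencyPositionalEmbedding()(x).shape)   # torch.Size([1, 3, 64, 100])
print(PositionalEmbedding2D(d=5)(x).shape)       # torch.Size([1, 21, 64, 100])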
# spectral discriminator base class
class SpecDiscriminatorBase(nn.Module):
    RECEPTIVE_FIELD_MAX_WIDTH = 10000

    def __init__(self,
                 layers,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7000],
                 noise_gain=1e-3,
                 fmap_start_index=0
                 ):
        super().__init__()

        self.layers = nn.ModuleList(layers)
        self.resolution = resolution
        self.fs = fs
        self.noise_gain = noise_gain
        self.fmap_start_index = fmap_start_index

        if fmap_start_index >= len(layers):
            raise ValueError('fmap_start_index is larger than number of layers')

        # filter bank for noise shaping
        n_fft = resolution[0]

        self.filterbank = nn.Parameter(
            gen_filterbank(n_fft // 2, fs, keep_size=True),
            requires_grad=False
        )

        # roi bins
        f_step = fs / n_fft
        self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
        self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft // 2 + 1)

        self.init_weights()

        # determine receptive field size, offsets and strides
        self.receptive_field_params = None
        hw = 1000
        while True:
            x = torch.zeros((1, hw, hw))
            with torch.no_grad():
                y = self.run_layer_stack(x)[-1]

            pos0 = [y.size(-2) // 2, y.size(-1) // 2]
            pos1 = [t + 1 for t in pos0]

            hs0, ws0 = self._receptive_field((hw, hw), pos0)
            hs1, ws1 = self._receptive_field((hw, hw), pos1)

            h0 = hs0[1] - hs0[0] + 1
            h1 = hs1[1] - hs1[0] + 1
            w0 = ws0[1] - ws0[0] + 1
            w1 = ws1[1] - ws1[0] + 1

            if h0 != h1 or w0 != w1:
                hw = 2 * hw
            else:
                # strides
                sh = hs1[0] - hs0[0]
                sw = ws1[0] - ws0[0]

                if sh == 0 or sw == 0:
                    # degenerate estimate; retry with a larger input
                    hw = 2 * hw
                    continue

                # offsets
                oh = hs0[0] - sh * pos0[0]
                ow = ws0[0] - sw * pos0[1]

                # overlap factor
                overlap = w0 / sw + h0 / sh

                #print(f"{w0=} {h0=} {sw=} {sh=} {overlap=}")
                self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}

                break

            if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
                print("warning: exceeded max size while trying to determine receptive field")
                break

        # create transposed convolutional kernel
        #self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)
    def run_layer_stack(self, spec):

        output = []

        x = spec.unsqueeze(1)

        for layer in self.layers:
            x = layer(x)
            output.append(x)

        return output

    def forward(self, x):
        """ returns array with feature maps and final score at index -1 """

        x = self.spectrogram(x)

        output = self.run_layer_stack(x)

        return output[self.fmap_start_index:]

    def receptive_field(self, output_pos):

        if self.receptive_field_params is not None:
            s, o, h = self.receptive_field_params['height']
            h_min = output_pos[0] * s + o + self.start_bin
            h_max = h_min + h
            h_min = max(h_min, self.start_bin)
            h_max = min(h_max, self.stop_bin)

            s, o, w = self.receptive_field_params['width']
            w_min = output_pos[1] * s + o
            w_max = w_min + w

            return (h_min, h_max), (w_min, w_max)

        else:
            return None, None

    def _receptive_field(self, input_dims, output_pos):
        """ determines receptive field probabilistically via autograd (slow) """

        x = torch.randn((1,) + input_dims, requires_grad=True)

        # run input through layers
        y = self.run_layer_stack(x)[-1]
        b, c, h, w = y.shape

        if output_pos[0] >= h or output_pos[1] >= w:
            raise ValueError("position out of range")

        mask = torch.zeros((b, c, h, w))
        mask[0, 0, output_pos[0], output_pos[1]] = 1

        (mask * y).sum().backward()

        hs, ws = torch.nonzero(x.grad[0], as_tuple=True)

        h_min, h_max = hs.min().item(), hs.max().item()
        w_min, w_max = ws.min().item(), ws.max().item()

        return [h_min, h_max], [w_min, w_max]
    def init_weights(self):

        for mod in self.modules():
            if isinstance(mod, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(mod.weight.data)

    def spectrogram(self, x):
        n_fft, hop_length, win_length = self.resolution
        x = x.squeeze(1)
        window = torch.hann_window(win_length).to(x.device)

        x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [B, F, T]
        x = torch.abs(x)

        # noise floor following spectral envelope
        smoothed_x = torch.matmul(self.filterbank, x)
        noise = torch.randn_like(x) * smoothed_x * self.noise_gain
        x = x + noise

        # frequency ROI
        x = x[:, self.start_bin : self.stop_bin + 1, ...]

        return torchaudio.functional.amplitude_to_DB(x, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)
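Standalone sketch of the noise-floor trick used in spectrogram() above (the identity filterbank is a stand-in for gen_filterbank, and the sizes are illustrative): envelope-shaped noise keeps near-silent bins from dominating the dB-scale input.

mag = torch.abs(torch.stft(torch.randn(1, 16000), n_fft=512, hop_length=128,
                           win_length=512, window=torch.hann_window(512),
                           return_complex=True))
filterbank = torch.eye(mag.size(1))          # placeholder for gen_filterbank(...)
smoothed = torch.matmul(filterbank, mag)     # smoothed spectral envelope
mag_aug = mag + 1e-3 * torch.randn_like(mag) * smoothed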
    def grad_map(self, x):
        self.zero_grad()

        n_fft, hop_length, win_length = self.resolution

        window = torch.hann_window(win_length).to(x.device)

        y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)

        specgram.requires_grad = True
        specgram.retain_grad()

        if specgram.grad is not None:
            specgram.grad.zero_()

        y = specgram[:, self.start_bin : self.stop_bin + 1, ...]

        scores = self.run_layer_stack(y)[-1]

        loss = torch.mean((1 - scores) ** 2)
        loss.backward()

        return specgram.data[0], torch.abs(specgram.grad)[0]
    def relevance_map(self, x):

        n_fft, hop_length, win_length = self.resolution
        y = x.view(-1)
        window = torch.hann_window(win_length).to(x.device)

        y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)

        scores = self.forward(x)[-1]

        sh, _, h = self.receptive_field_params['height']
        sw, _, w = self.receptive_field_params['width']
        kernel = create_kernel(h, w, sh, sw).float().to(scores.device)
        with torch.no_grad():
            pad_w = (w + sw - 1) // sw
            pad_h = (h + sh - 1) // sh
            padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
            # CAVE: padding should be derived from offsets
            rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h // 2, w // 2))
            rv = rv[..., pad_h * sh : -pad_h * sh, pad_w * sw : -pad_w * sw]

        relevance = torch.zeros_like(specgram)
        relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv

        return specgram, relevance
    def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
        """ layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """

        # ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations
        def newconv2d(layer, g):

            new_layer = nn.Conv2d(layer.in_channels,
                                  layer.out_channels,
                                  layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  dilation=layer.dilation,
                                  groups=layer.groups)

            try:
                new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
            except AttributeError:
                pass

            try:
                new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
            except AttributeError:
                pass

            return new_layer

        bounds = {
            64: [-85.82449722290039, 2.1755014657974243],
            128: [-84.49211349487305, 3.5078893899917607],
            256: [-80.33127822875977, 7.6687201976776125],
            512: [-73.79328079223633, 14.20672025680542],
            1024: [-67.59239501953125, 20.40760498046875],
            2048: [-62.31902580261231, 25.680974197387698],
        }

        nfft = self.resolution[0]
        if low is None: low = bounds[nfft][0]
        if high is None: high = bounds[nfft][1]

        remove_all_weight_norms(self)

        for p in self.parameters():
            if p.grad is not None:
                p.grad.zero_()

        num_layers = len(self.layers)
        X = self.spectrogram(x).detach()

        # forward pass
        A = [X.unsqueeze(1)] + [None] * len(self.layers)

        for i in range(num_layers - 1):
            A[i + 1] = self.layers[i](A[i])

        # initial relevance is last layer without activation
        r = A[-2]
        last_layer_rs = [r]
        layer = self.layers[-1]
        for sublayer in list(layer)[:-1]:
            r = sublayer(r)
            last_layer_rs.append(r)

        mask = torch.zeros_like(r)
        mask.requires_grad_(False)
        if verbose:
            print(r.min(), r.max())
        if label in {'both', 'fake'}:
            mask[r < -threshold] = 1
        if label in {'both', 'real'}:
            mask[r > threshold] = 1
        r = r * mask

        # backward pass
        R = [None] * num_layers + [r]

        for l in range(1, num_layers)[::-1]:
            A[l] = (A[l]).data.requires_grad_(True)

            layer = nn.Sequential(*(list(self.layers[l])[:-1]))
            z = layer(A[l]) + eps
            s = (R[l + 1] / z).data
            (z * s).sum().backward()
            c = A[l].grad
            R[l] = (A[l] * c).data

        # first layer
        A[0] = (A[0].data).requires_grad_(True)

        Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
        Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)

        if len(list(self.layers)) > 2:
            # unsafe way to check for embedding layer
            embed = list(self.layers[0])[0]
            conv = list(self.layers[0])[1]

            layer = nn.Sequential(embed, conv)
            layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
            layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))

        else:
            layer = list(self.layers[0])[0]
            layerl = newconv2d(layer, lambda p: p.clamp(min=0))
            layerh = newconv2d(layer, lambda p: p.clamp(max=0))

        z = layer(A[0])
        z -= layerl(Xl) + layerh(Xh)
        s = (R[1] / z).data
        (z * s).sum().backward()
        c, cp, cm = A[0].grad, Xl.grad, Xh.grad

        R[0] = (A[0] * c + Xl * cp + Xh * cm)
        #R[0] = (A[0] * c).data

        return X, R[0].mean(dim=1)
def create_3x3_conv_plan(num_layers : int,
                         f_stretch : int,
                         f_down : int,
                         t_stretch : int,
                         t_down : int
                         ):
    """ creates a stride, dilation, padding plan for a 2d conv network

    Args:
        num_layers (int): number of layers
        f_stretch (int): log_2 of stretching factor along frequency axis
        f_down (int): log_2 of downsampling factor along frequency axis
        t_stretch (int): log_2 of stretching factor along time axis
        t_down (int): log_2 of downsampling factor along time axis

    Returns:
        list(list(tuple)): list containing entries [(stride_f, stride_t), (dilation_f, dilation_t), (padding_f, padding_t)]
    """

    assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
    assert f_stretch < num_layers and t_stretch < num_layers

    def process_dimension(n_layers, stretch, down):

        stack_layers = n_layers - 1

        stride_layers = min(min(down, stretch), stack_layers)
        dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)
        final_stride = 2 ** (max(down - stride_layers, 0))

        final_dilation = 1
        if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
            final_dilation = 2

        strides, dilations, paddings = [], [], []
        processed_layers = 0
        current_dilation = 1

        for _ in range(stride_layers):
            # increase receptive field and downsample via stride = 2
            strides.append(2)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        if processed_layers < stack_layers:
            strides.append(1)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        for _ in range(dilation_layers):
            # increase receptive field via dilation = 2
            strides.append(1)
            current_dilation *= 2
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        while processed_layers < n_layers - 1:
            # fill up with std layers
            strides.append(1)
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        # final layer
        strides.append(final_stride)
        current_dilation *= final_dilation
        dilations.append(current_dilation)
        paddings.append(current_dilation)
        processed_layers += 1

        assert processed_layers == n_layers

        return strides, dilations, paddings

    t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
    f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)

    plan = []

    for i in range(num_layers):
        plan.append([
            (f_strides[i], t_strides[i]),
            (f_dilations[i], t_dilations[i]),
            (f_paddings[i], t_paddings[i]),
        ])

    return plan
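Illustration of the planner (parameter values chosen arbitrarily): a 6-layer plan with 2^2 frequency stretch and downsampling and 2^2 time stretch without time downsampling.

plan = create_3x3_conv_plan(6, f_stretch=2, f_down=2, t_stretch=2, t_down=0)
for stride, dilation, padding in plan:
    print(stride, dilation, padding)   # per-layer (f, t) tuples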
class DiscriminatorExperimental(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding = (1, 1)
        in_channels = 1 + 2
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight += bias_val  # in-place so the parameter is actually shifted
configs = {
    'f_down': {
        'stretch': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'ft_down': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'dilated': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (0, 0),
            256: (0, 0),
            512: (0, 0),
            1024: (0, 0),
            2048: (0, 0)
        }
    },
    'mixed': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
}
class DiscriminatorMagFree(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=256,
                 num_layers=5,
                 use_spectral_norm=False,
                 design=None):

        if design is None:
            raise ValueError('error: arch required in DiscriminatorMagFree')

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        stretch = configs[design]['stretch'][resolution[0]]
        down = configs[design]['down'][resolution[0]]

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.stretch = stretch
        self.down = down

        layers = []
        plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
        in_channels = 1 + 2
        out_channels = self.num_channels
        for i in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            # product over strides
            channel_factor = plan[i][0][0] * plan[i][0][1]
            out_channels = min(channel_factor * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight += bias_val  # in-place so the parameter is actually shifted
class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding = (1, 1)
        in_channels = 1 + 2
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag2dPositional(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 d=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.resolution = resolution
        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.d = d
        embedding_dim = 4 * d

        layers = []
        stride = (2, 2)
        padding = (1, 1)
        in_channels = 1 + embedding_dim
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    PositionalEmbedding2D(d),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + embedding_dim
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                PositionalEmbedding2D(d),  # pass d so the channel count matches embedding_dim
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag(SpecDiscriminatorBase):
    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=32,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_layers = num_layers

        layers = []
        stride = (1, 1)
        padding = (1, 1)
        in_channels = 1
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels

        layers.append(norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)))

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
discriminators = {
    'mag': DiscriminatorMag,
    'freqpos': DiscriminatorMagFreqPosition,
    '2dpos': DiscriminatorMag2dPositional,
    'experimental': DiscriminatorExperimental,
    'free': DiscriminatorMagFree
}
class TFDMultiResolutionDiscriminator(torch.nn.Module):
    def __init__(self,
                 fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
                 architecture='mag',
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 use_spectral_norm=False,
                 **kwargs):

        super().__init__()

        fft_sizes = [int(round(fft_size_16k * fs / 16000)) for fft_size_16k in fft_sizes_16k]

        resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes]

        Disc = discriminators[architecture]

        discs = [Disc(resolutions[i], fs=fs, freq_roi=freq_roi, noise_gain=noise_gain, use_spectral_norm=use_spectral_norm, **kwargs) for i in range(len(resolutions))]

        self.discriminators = nn.ModuleList(discs)

    def forward(self, y):
        outputs = []

        for disc in self.discriminators:
            outputs.append(disc(y))

        return outputs
class FWGAN_disc_wrapper(nn.Module):
    def __init__(self, disc):
        super().__init__()

        self.disc = disc

    def forward(self, y, y_hat):

        out_real = self.disc(y)
        out_fake = self.disc(y_hat)

        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for y_real, y_fake in zip(out_real, out_fake):
            y_d_rs.append(y_real[-1])
            y_d_gs.append(y_fake[-1])
            fmap_rs.append(y_real[:-1])
            fmap_gs.append(y_fake[:-1])

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
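Hedged usage sketch for the discriminator bank (the input length and keyword arguments are assumptions, not from this commit): each sub-discriminator returns its feature maps with the final score map at index -1.

disc = TFDMultiResolutionDiscriminator(architecture='free', design='mixed')
x = torch.randn(1, 1, 16000)                 # one second at 16 kHz
for maps in disc(x):
    print(maps[-1].shape)                    # score map per resolution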
254
dnn/torch/osce/models/lavoce.py
Normal file
@@ -0,0 +1,254 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.noise_shaper import NoiseShaper
from utils.complexity import _conv1d_flop_count
from utils.endoscopy import write_data

from models.nns_base import NNSBase
from models.lpcnet_feature_net import LPCNetFeatureNet
from .scale_embedding import ScaleEmbedding
class LaVoce(nn.Module):
    """ Linear-Adaptive VOCodEr """
    FEATURE_FRAME_SIZE = 160
    FRAME_SIZE = 80

    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False):

        super().__init__()

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.pulses = pulses

        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # feature net
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)

        # noise shaper
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)

        # comb filters
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)

        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)

        # combinators
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
    def create_phase_signals(self, periods, pulses=False):

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if pulses:
                alpha = torch.cos(f)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim=1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim=1)

            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals
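Standalone sketch of the phase-oscillator idea above (not the class method itself; values are illustrative): the per-frame offset phase0 is carried over so the sinusoids stay continuous across frame boundaries.

frame_size = 80
periods = torch.tensor([[100., 100., 120.]])   # pitch period in samples, per frame
phase0 = torch.zeros(1, 1)
chunks = []
for k in range(periods.size(1)):
    f = 2.0 * torch.pi / periods[:, k:k+1]
    n = torch.arange(1, frame_size + 1).view(1, -1)
    chunks.append(torch.sin(f * n + phase0))
    phase0 = phase0 + frame_size * f           # accumulate phase into next frame
sig = torch.cat(chunks, dim=-1)                # (1, 3 * frame_size)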
    def flop_count(self, rate=16000, verbose=False):

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops
    def feature_transform(self, f, layer):
        f = f.permute(0, 2, 1)
        f = F.pad(f, [1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)
    def forward(self, features, periods, debug=False):

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)

        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # upsample periods
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

        # pre-net
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)

        if debug:
            ch0 = y[0, 0, :].detach().cpu().numpy()
            ch1 = y[0, 1, :].detach().cpu().numpy()
            ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16)
            ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16)
            write_data('prior_channel0', ch0, 16000)
            write_data('prior_channel1', ch1, 16000)

        # temporal shaping + innovating
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)

        # spectral shaping
        y = self.cf1(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)

        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)

        return y

    def process(self, features, periods, debug=False):

        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():

            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)

            y = self.forward(f, p, debug=debug).squeeze()

            # deemphasis
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # clip to valid range
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out
91
dnn/torch/osce/models/lpcnet_feature_net.py
Normal file
@@ -0,0 +1,91 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
import torch.nn.functional as F

from utils.complexity import _conv1d_flop_count

class LPCNetFeatureNet(nn.Module):

    def __init__(self,
                 feature_dim=84,
                 num_channels=256,
                 upsamp_factor=2,
                 lookahead=True):

        super().__init__()

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.upsamp_factor = upsamp_factor
        self.lookahead = lookahead

        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
        self.conv2 = nn.Conv1d(num_channels, num_channels, 3)

        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)

        self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)

    def flop_count(self, rate=100):
        count = 0
        for conv in self.conv1, self.conv2, self.tconv:
            count += _conv1d_flop_count(conv, rate)

        count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate

        return count

    def forward(self, features, state=None):
        """ features shape: (batch_size, num_frames, feature_dim) """

        batch_size = features.size(0)

        if state is None:
            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)

        features = features.permute(0, 2, 1)
        if self.lookahead:
            c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
            c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
        else:
            c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
            c = torch.tanh(self.conv2(F.pad(c, [2, 0])))

        c = torch.tanh(self.tconv(c))

        c = c.permute(0, 2, 1)

        c, _ = self.gru(c, state)

        return c
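Shape check for the feature net (feature_dim=84 matches the 20 features plus 64-dim pitch embedding used by LaVoce; the other numbers are illustrative):

net = LPCNetFeatureNet(feature_dim=84, num_channels=256, upsamp_factor=2)
feats = torch.randn(1, 10, 84)     # (batch, num_frames, feature_dim)
cond = net(feats)
print(cond.shape)                  # torch.Size([1, 20, 256]): frames upsampled by 2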
@@ -1,3 +1,31 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
103
dnn/torch/osce/test_vocoder.py
Normal file
@@ -0,0 +1,103 @@
""" (standard 2-clause BSD license header: Copyright (c) 2023 Amazon, written by Jan Buethe) """
import argparse

import torch

from scipy.io import wavfile

from time import time


from models import model_dict
from utils.lpcnet_features import load_lpcnet_features
from utils import endoscopy


debug = False
if debug:
    args = type('dummy', (object,),
    {
        'input' : 'testitems/all_0_orig.se',
        'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
        'output' : 'out.wav',
        'debug' : True,  # args.debug is read below, so the dummy args object needs this attribute
    })()
else:
    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str, help='path to input features')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--debug', action='store_true', help='enables debug output')

    args = parser.parse_args()


torch.set_num_threads(2)

input_folder = args.input
checkpoint_file = args.checkpoint

output_file = args.output
if not output_file.endswith('.wav'):
    output_file += '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# check model
if 'name' not in checkpoint['setup']['model']:
    print('warning: did not find model name entry in setup, using pitchpostfilter by default')
    model_name = 'pitchpostfilter'
else:
    model_name = checkpoint['setup']['model']['name']

model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])

model.load_state_dict(checkpoint['state_dict'])

# generate model input
setup = checkpoint['setup']
testdata = load_lpcnet_features(input_folder)
features = testdata['features']
periods = testdata['periods']

if args.debug:
    endoscopy.init()

start = time()
output = model.process(features, periods, debug=args.debug)
elapsed = time() - start
print(f"[timing] inference took {1000 * elapsed:.2f} ms")

wavfile.write(output_file, 16000, output.cpu().numpy())

if args.debug:
    endoscopy.close()
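For reference, a typical invocation of this script might look as follows (the paths are hypothetical, taken from the debug stub above):

    python test_vocoder.py testitems/all_0_orig.se testout/checkpoints/checkpoint_epoch_5.pth out.wav --debug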
287
dnn/torch/osce/train_vocoder.py
Normal file
@@ -0,0 +1,287 @@
""" (standard 2-clause BSD license header: Copyright (c) 2023 Amazon, written by Jan Buethe) """
import os
import argparse
import sys

import yaml

try:
    import git
    has_git = True
except:
    has_git = False

import torch
from torch.optim.lr_scheduler import LambdaLR

from scipy.io import wavfile

import pesq

from data import LPCNetVocodingDataset
from models import model_dict
from engine.vocoder_engine import train_one_epoch, evaluate


from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

# optional reference signal for PESQ scoring (not set up here)
ref = None

# prepare inference test if wanted
inference_test = False
if args.test_features is not None:
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True


# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)

# optimizer on trainable parameters only
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))

# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

def xcorr_loss(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup' : setup,
    'state_dict' : model.state_dict(),
    'loss' : -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

print("summary:")

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")

if ref is not None:
    pass

best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)


    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss'] = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
            best_loss = validation_loss

    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))


    print()

print('Done')
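A typical launch of this script might look as follows (the setup file name is hypothetical; see the setup presets added further down in this commit):

    python train_vocoder.py lavoce_setup.yml output/lavoce --test-features testitems/all_0_orig.se

Note that criterion divides the weighted sum of loss terms by w_sum, so the weights in the setup file only fix the relative balance between terms, not the overall loss scale.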
@@ -1,31 +1,4 @@
-""" (standard 2-clause BSD license header removed: Copyright (c) 2023 Amazon, written by Jan Buethe) """
-
 def _conv1d_flop_count(layer, rate):
     return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0]) * layer.kernel_size[0]
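As a quick sanity check of this formula (illustrative numbers, taken from LPCNetFeatureNet.conv1 above): with in_channels=84, out_channels=256, kernel_size=3, stride=1 and rate=100 frames/s, the count is 2 * (84 + 1) * 256 * (100 / 1) * 3 = 13,056,000, i.e. about 13.06 MFLOPS for that single convolution.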
@@ -1,32 +1,3 @@
-""" (standard 2-clause BSD license header removed: Copyright (c) 2023 Amazon, written by Jan Buethe) """
-
 """ module for inspecting models during inference """

 import os
100
dnn/torch/osce/utils/layers/noise_shaper.py
Normal file
@@ -0,0 +1,100 @@
""" (standard 2-clause BSD license header: Copyright (c) 2023 Amazon, written by Jan Buethe) """
import torch
from torch import nn
import torch.nn.functional as F

from utils.complexity import _conv1d_flop_count

class NoiseShaper(nn.Module):

    def __init__(self,
                 feature_dim,
                 frame_size=160
                 ):
        """

        Parameters:
        -----------

        feature_dim : int
            dimension of input features

        frame_size : int
            frame size

        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size

        # feature transform
        self.feature_alpha1 = nn.Conv1d(self.feature_dim, frame_size, 2)
        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)


    def flop_count(self, rate):

        frame_rate = rate / self.frame_size

        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

        return shape_flops


    def forward(self, features):
        """ creates temporally shaped noise


        Parameters:
        -----------
        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        """

        batch_size = features.size(0)
        num_frames = features.size(1)
        frame_size = self.frame_size
        num_samples = num_frames * frame_size

        # feature path
        f = F.pad(features.permute(0, 2, 1), [1, 0])
        alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        # signal generation
        y = torch.randn((batch_size, num_frames, frame_size), dtype=features.dtype, device=features.device)
        y = alpha * y

        return y.reshape(batch_size, 1, num_samples)
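A minimal usage sketch (shapes follow the docstring; the feature dimension and frame count are placeholders):

    shaper = NoiseShaper(feature_dim=84)
    feats = torch.randn(1, 50, 84)   # (batch_size, num_frames, feature_dim)
    noise = shaper(feats)            # -> (1, 1, 8000): one gain per sample within each 160-sample frame

Since alpha comes out of a torch.exp, the per-sample gains are strictly positive: the module learns a temporal envelope that is imposed on otherwise white noise.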
@@ -1,3 +1,32 @@
+""" (standard 2-clause BSD license header added: Copyright (c) 2023 Amazon, written by Jan Buethe) """
+
 """ This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """

 import torch
@@ -11,7 +11,8 @@ class TDShaper(nn.Module):
                  feature_dim,
                  frame_size=160,
                  avg_pool_k=4,
-                 innovate=False
+                 innovate=False,
+                 pool_after=False
                  ):
         """

@@ -39,6 +40,7 @@ class TDShaper(nn.Module):
         self.frame_size = frame_size
         self.avg_pool_k = avg_pool_k
         self.innovate = innovate
+        self.pool_after = pool_after

         assert frame_size % avg_pool_k == 0
         self.env_dim = frame_size // avg_pool_k + 1

@@ -71,8 +73,12 @@ class TDShaper(nn.Module):
     def envelope_transform(self, x):

         x = torch.abs(x)
-        x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
-        x = torch.log(x + .5**16)
+        if self.pool_after:
+            x = torch.log(x + .5**16)
+            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
+        else:
+            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
+            x = torch.log(x + .5**16)

         x = x.reshape(x.size(0), -1, self.env_dim - 1)
         avg_x = torch.mean(x, -1, keepdim=True)
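A note on the new pool_after option: the original order averages linear-domain magnitudes and then takes the log, while pool_after takes the log first and then averages, which amounts to a geometric rather than arithmetic mean of the magnitudes in each pooling window and, by Jensen's inequality, never yields a larger envelope value. For instance, for a window holding magnitudes 1 and 9 with avg_pool_k=2, pooling first gives log(5) ≈ 1.61, whereas log-then-pool gives (log 1 + log 9) / 2 ≈ 1.10. The .5**16 offset keeps the log bounded for all-zero windows in either ordering.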
112
dnn/torch/osce/utils/lpcnet_features.py
Normal file
@@ -0,0 +1,112 @@
import os

import torch
import numpy as np

def load_lpcnet_features(feature_file, version=2):
    if version == 2:
        layout = {
            'cepstrum': [0, 18],
            'periods': [18, 19],
            'pitch_corr': [19, 20],
            'lpc': [20, 36]
        }
        frame_length = 36

    elif version == 1:
        layout = {
            'cepstrum': [0, 18],
            'periods': [36, 37],
            'pitch_corr': [37, 38],
            'lpc': [39, 55],
        }
        frame_length = 55
    else:
        raise ValueError(f'unknown feature version: {version}')


    raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    raw_features = raw_features.reshape((-1, frame_length))

    features = torch.cat(
        [
            raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
            raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
        ],
        dim=1
    )

    lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
    periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()

    return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}



def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)


    assert len(signal) % 160 == 0
    num_frames = len(signal) // 160
    mem = np.zeros(1)
    for fr in range(num_frames):
        # frame-wise pre-emphasis filter 1 - preemph_factor * z^-1 with filter memory carried across frames
        signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
        mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)

    new_data[:] = 0
    N = len(signal) - offset
    new_data[1 : 2*N + 1 : 2] = signal_preemph[offset:]
    new_data[2 : 2*N + 2 : 2] = signal_preemph[offset:]


def parse_warpq_scores(output_file):
    """ extracts warpq scores from output file """

    with open(output_file, "r") as f:
        lines = f.readlines()

    scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]

    return scores


def parse_stats_file(file):

    with open(file, "r") as f:
        lines = f.readlines()

    mean = float(lines[0].split(":")[-1])
    bt_mean = float(lines[1].split(":")[-1])
    top_mean = float(lines[2].split(":")[-1])

    return mean, bt_mean, top_mean

def collect_test_stats(test_folder):
    """ collects statistics for all discovered metrics from test folder """

    metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}

    results = dict()

    content = os.listdir(test_folder)

    stats_files = [file for file in content if file.startswith('stats_')]

    for file in stats_files:
        metric = file[len("stats_") : -len(".txt")]

        if metric not in metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))

        results[metric] = [mean, bt_mean, top_mean]

    return results
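A minimal usage sketch tying this to test_vocoder.py above (the file name is hypothetical):

    data = load_lpcnet_features('all_0_orig.se')   # version-2 layout: 36 floats per frame
    feats = data['features']                       # (num_frames, 19): 18 cepstral coefficients + pitch correlation
    periods = data['periods']                      # (num_frames, 1): integer pitch periods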
@@ -39,4 +39,27 @@ def count_parameters(model, verbose=False):

         total += count

     return total
+
+
+def retain_grads(module):
+    for p in module.parameters():
+        if p.requires_grad:
+            p.retain_grad()
+
+def get_grad_norm(module, p=2):
+    norm = 0
+    for param in module.parameters():
+        if param.requires_grad:
+            norm = norm + (torch.abs(param.grad) ** p).sum()
+
+    return norm ** (1/p)
+
+def create_weights(s_real, s_gen, alpha):
+    weights = []
+    with torch.no_grad():
+        for sr, sg in zip(s_real, s_gen):
+            weight = torch.exp(alpha * (sr[-1] - sg[-1]))
+            weights.append(weight)
+
+    return weights
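A short note on create_weights: it computes an exponential sample weighting from the discriminators' final outputs, weight = exp(alpha * (s_real - s_gen)). For example, with alpha = 1, a sample where the discriminator scores real at 0.9 and generated at 0.2 gets weight exp(0.7) ≈ 2.01, so samples the discriminator still separates well are emphasized. A hypothetical call, assuming s_real and s_gen are lists of per-discriminator activation lists with the score tensor last (matching the sr[-1] indexing above):

    weights = create_weights(s_real, s_gen, alpha=1.0)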
@@ -27,7 +27,6 @@
 */
 """

-
 import os

 import numpy as np
@@ -30,6 +30,7 @@
 import math as m
 import numpy as np
 import scipy
+import torch

 def erb(f):
     return 24.7 * (4.37 * f + 1)
@@ -49,6 +50,20 @@ scale_dict = {
     'erb': [erb, inv_erb]
 }

+def gen_filterbank(N, Fs=16000, keep_size=False):
+    in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
+    M = N + 1 if keep_size else N
+    out_freq = (np.arange(M, dtype='float32')/N*Fs/2)[:,None]
+    # ERB from B.C.J. Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
+    ERB_N = 24.7 + .108*in_freq
+    delta = np.abs(in_freq-out_freq)/ERB_N
+    center = (delta<.5).astype('float32')
+    R = -12*center*delta**2 + (1-center)*(3-12*delta)
+    RE = 10.**(R/10.)
+    norm = np.sum(RE, axis=1)
+    RE = RE/norm[:, np.newaxis]
+    return torch.from_numpy(RE)
+
 def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):

     f0 = 0
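A brief usage sketch for gen_filterbank (sizes assumed): gen_filterbank(160) returns a (160, 161) matrix whose rows are ERB-shaped responses normalized to sum to 1, so multiplying it with a 161-bin power spectrum yields a perceptually smeared spectrum:

    fb = gen_filterbank(160)        # (160, 161), rows sum to 1
    smeared = torch.matmul(fb, X)   # X: (161,) float32 power spectrum -> (160,)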
@@ -140,8 +140,196 @@ nolace_setup = {
    }
}

nolace_setup_adv = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        'pitch_hangover': 8,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
    },
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    'training': {
        'adv_target': 'target_orig',
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'loss': {
            'w_l1': 0,
            'w_l2': 10,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 20,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09,
    }
}


lavoce_setup = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal'
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True
        },
        'name': 'lavoce'
    },
    'training': {
        'batch_size': 256,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0
        },
        'lr': 0.0005,
        'lr_decay_factor': 2.5e-05
    },
    'validation_dataset': '/local/datasets/lpcnet_large/validation'
}

lavoce_setup_adv = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal'
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True
        },
        'name': 'lavoce'
    },
    'training': {
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09
    },
}


setup_dict = {
    'lace': lace_setup,
    'nolace': nolace_setup,
    'nolace_adv': nolace_setup_adv,
    'lavoce': lavoce_setup,
    'lavoce_adv': lavoce_setup_adv
}
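These presets are what the training scripts in this commit consume; a hypothetical way to materialize one as a setup file for train_vocoder.py:

    import yaml
    with open('lavoce_setup.yml', 'w') as f:
        yaml.dump(setup_dict['lavoce'], f)

followed by: python train_vocoder.py lavoce_setup.yml output/lavoce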