Opus ng lace

Jan Buethe 2023-06-30 21:15:56 +00:00 committed by Jean-Marc Valin
parent 178672ed18
commit 105e1d83fa
24 changed files with 2937 additions and 0 deletions

dnn/torch/osce/README.md Normal file
View file

@@ -0,0 +1,4 @@
# Opus Speech Coding Enhancement
This folder hosts models for enhancing SILK. See the `exp-neural-silk-enhancement` branch of the Opus repository (https://gitlab.xiph.org/xiph/opus/-/tree/exp-neural-silk-enhancement) for feature generation.
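
A typical workflow with the scripts added in this commit might look as follows; paths, the dataset layout and output names below are placeholders, not part of the commit:

# create a default setup file for the LACE model
python make_default_setup.py lace_setup.yml --model lace --path2dataset /path/to/dataset

# train; checkpoints and logs are written to the output folder
python train_model.py lace_setup.yml output

# run a trained checkpoint on a folder with features and coded signal
python test_model.py /path/to/features output/checkpoints/checkpoint_last.pth enhanced.wav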

View file

@@ -0,0 +1,30 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from .silk_enhancement_set import SilkEnhancementSet

View file

@@ -0,0 +1,140 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
from torch.utils.data import Dataset
import numpy as np
from utils.silk_features import silk_feature_factory
from utils.pitch import hangover, calculate_acorr_window
class SilkEnhancementSet(Dataset):
def __init__(self,
path,
frames_per_sample=100,
no_pitch_value=256,
preemph=0.85,
skip=91,
acorr_radius=2,
pitch_hangover=8,
num_bands_clean_spec=64,
num_bands_noisy_spec=18,
noisy_spec_scale='opus',
noisy_apply_dct=True,
add_offset=False,
add_double_lag_acorr=False
):
assert frames_per_sample % 4 == 0
self.frame_size = 80
self.frames_per_sample = frames_per_sample
self.no_pitch_value = no_pitch_value
self.preemph = preemph
self.skip = skip
self.acorr_radius = acorr_radius
self.pitch_hangover = pitch_hangover
self.num_bands_clean_spec = num_bands_clean_spec
self.num_bands_noisy_spec = num_bands_noisy_spec
self.noisy_spec_scale = noisy_spec_scale
self.add_double_lag_acorr = add_double_lag_acorr
self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
self.clean_signal = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)
self.create_features = silk_feature_factory(no_pitch_value,
acorr_radius,
pitch_hangover,
num_bands_clean_spec,
num_bands_noisy_spec,
noisy_spec_scale,
noisy_apply_dct,
add_offset,
add_double_lag_acorr)
self.history_len = 700 if add_double_lag_acorr else 350
# discard some frames to have enough signal history
self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)
num_frames = self.clean_signal.shape[0] // 80 - self.skip_frames
self.len = num_frames // frames_per_sample
def __len__(self):
return self.len
def __getitem__(self, index):
frame_start = self.frames_per_sample * index + self.skip_frames
frame_stop = frame_start + self.frames_per_sample
signal_start = frame_start * self.frame_size - self.skip
signal_stop = frame_stop * self.frame_size - self.skip
clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15
coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15
features, periods = self.create_features(
coded_signal,
coded_signal_history,
self.lpcs[frame_start : frame_stop],
self.gains[frame_start : frame_stop],
self.ltps[frame_start : frame_stop],
self.periods[frame_start : frame_stop],
self.offsets[frame_start : frame_stop]
)
if self.preemph > 0:
clean_signal[1:] -= self.preemph * clean_signal[: -1]
coded_signal[1:] -= self.preemph * coded_signal[: -1]
num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)
return {
'features' : features,
'periods' : periods.astype(np.int64),
'target' : clean_signal.astype(np.float32),
'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
'numbits' : numbits.astype(np.float32)
}
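
For orientation, a minimal usage sketch of the dataset class; the path is a placeholder and assumes the feature files listed in __init__ are present:

import torch
from data import SilkEnhancementSet

dataset = SilkEnhancementSet('/path/to/dataset')  # placeholder path
sample = dataset[0]
# 'features': (frames, feature_dim), 'periods': (frames,),
# 'target': (frames * 80,), 'signals': (frames * 80, 1), 'numbits': (frames, 2)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)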

View file

@@ -0,0 +1,101 @@
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
model.to(device)
model.train()
running_loss = 0
previous_running_loss = 0
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# set gradients to zero
optimizer.zero_grad()
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target']
# calculate model output
output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
# calculate loss
if isinstance(output, list):
loss = torch.zeros(1, device=device)
for y in output:
loss = loss + criterion(target, y.squeeze(1))
loss = loss / len(output)
else:
loss = criterion(target, output.squeeze(1))
# calculate gradients
loss.backward()
# update weights
optimizer.step()
# update learning rate
scheduler.step()
# update running loss
running_loss += float(loss.cpu())
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
previous_running_loss = running_loss
running_loss /= len(dataloader)
return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
model.to(device)
model.eval()
running_loss = 0
previous_running_loss = 0
with torch.no_grad():
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target']
# calculate model output
output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
# calculate loss
loss = criterion(target, output.squeeze(1))
# update running loss
running_loss += float(loss.cpu())
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
previous_running_loss = running_loss
running_loss /= len(dataloader)
return running_loss

View file

@@ -0,0 +1,277 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import torchaudio
def get_window(win_name, win_length, *args, **kwargs):
window_dict = {
'bartlett_window' : torch.bartlett_window,
'blackman_window' : torch.blackman_window,
'hamming_window' : torch.hamming_window,
'hann_window' : torch.hann_window,
'kaiser_window' : torch.kaiser_window
}
    if win_name not in window_dict:
        raise ValueError(f"unknown window: {win_name}")
return window_dict[win_name](win_length, *args, **kwargs)
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames).
"""
win = get_window(window, win_length).to(x.device)
x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)
return torch.clamp(torch.abs(x_stft), min=1e-7)
def spectral_convergence_loss(Y_true, Y_pred):
dims=list(range(1, len(Y_pred.shape)))
return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6))
def log_magnitude_loss(Y_true, Y_pred):
Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15)
Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15)
return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs))
def spectral_xcorr_loss(Y_true, Y_pred):
Y_true = Y_true.abs()
Y_pred = Y_pred.abs()
dims=list(range(1, len(Y_pred.shape)))
xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9)
return 1 - xcorr.mean()
class MRLogMelLoss(nn.Module):
def __init__(self,
fft_sizes=[512, 256, 128, 64],
overlap=0.5,
fs=16000,
n_mels=18
):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels
self.mel_specs = []
for fft_size in fft_sizes:
hop_size = int(round(fft_size * (1 - self.overlap)))
n_mels = self.n_mels
if fft_size < 128:
n_mels //= 2
self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
for i, mel_spec in enumerate(self.mel_specs):
self.add_module(f'mel_spec_{i+1}', mel_spec)
def forward(self, y_true, y_pred):
loss = torch.zeros(1, device=y_true.device)
for mel_spec in self.mel_specs:
Y_true = mel_spec(y_true)
Y_pred = mel_spec(y_pred)
loss = loss + log_magnitude_loss(Y_true, Y_pred)
loss = loss / len(self.mel_specs)
return loss
def create_weight_matrix(num_bins, bins_per_band=10):
m = torch.zeros((num_bins, num_bins), dtype=torch.float32)
r0 = bins_per_band // 2
r1 = bins_per_band - r0
for i in range(num_bins):
i0 = max(i - r0, 0)
j0 = min(i + r1, num_bins)
m[i, i0: j0] += 1
if i < r0:
m[i, :r0 - i] += 1
if i > num_bins - r1:
m[i, num_bins - r1 - i:] += 1
return m / bins_per_band
def weighted_spectral_convergence(Y_true, Y_pred, w):
# calculate sfm based weights
logY = torch.log(torch.abs(Y_true) + 1e-9)
Y = torch.abs(Y_true)
avg_logY = torch.matmul(logY.transpose(1, 2), w)
avg_Y = torch.matmul(Y.transpose(1, 2), w)
sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
loss = torch.mean(
torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
/ (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
)
return loss
def gen_filterbank(N, Fs=16000):
in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]
#ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
ERB_N = 24.7 + .108*in_freq
delta = np.abs(in_freq-out_freq)/ERB_N
center = (delta<.5).astype('float32')
R = -12*center*delta**2 + (1-center)*(3-12*delta)
RE = 10.**(R/10.)
norm = np.sum(RE, axis=1)
RE = RE/norm[:, np.newaxis]
return torch.from_numpy(RE)
def smooth_log_mag(Y_true, Y_pred, filterbank):
Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true))
Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred))
loss = torch.abs(
torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9)
)
loss = loss.mean()
return loss
class MRSTFTLoss(nn.Module):
def __init__(self,
fft_sizes=[2048, 1024, 512, 256, 128, 64],
overlap=0.5,
window='hann_window',
fs=16000,
log_mag_weight=1,
sc_weight=0,
wsc_weight=0,
smooth_log_mag_weight=0,
sxcorr_weight=0):
super().__init__()
self.fft_sizes = fft_sizes
self.overlap = overlap
self.window = window
self.log_mag_weight = log_mag_weight
self.sc_weight = sc_weight
self.wsc_weight = wsc_weight
self.smooth_log_mag_weight = smooth_log_mag_weight
self.sxcorr_weight = sxcorr_weight
self.fs = fs
# weights for SFM weighted spectral convergence loss
self.wsc_weights = torch.nn.ParameterDict()
for fft_size in fft_sizes:
width = min(11, int(1000 * fft_size / self.fs + .5))
width += width % 2
self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
create_weight_matrix(fft_size // 2 + 1, width),
requires_grad=False
)
# filterbanks for smooth log magnitude loss
self.filterbanks = torch.nn.ParameterDict()
for fft_size in fft_sizes:
self.filterbanks[str(fft_size)] = torch.nn.Parameter(
gen_filterbank(fft_size//2),
requires_grad=False
)
def __call__(self, y_true, y_pred):
lm_loss = torch.zeros(1, device=y_true.device)
sc_loss = torch.zeros(1, device=y_true.device)
wsc_loss = torch.zeros(1, device=y_true.device)
slm_loss = torch.zeros(1, device=y_true.device)
sxcorr_loss = torch.zeros(1, device=y_true.device)
for fft_size in self.fft_sizes:
hop_size = int(round(fft_size * (1 - self.overlap)))
win_size = fft_size
Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
if self.log_mag_weight > 0:
lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
if self.sc_weight > 0:
sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
if self.wsc_weight > 0:
wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
if self.smooth_log_mag_weight > 0:
slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
if self.sxcorr_weight > 0:
sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
+ self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
+ self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
return total_loss
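
A quick smoke test of the combined loss; the random (batch, samples) tensors stand in for real audio and the weight choice is arbitrary:

import torch

loss_fn = MRSTFTLoss(log_mag_weight=1, sxcorr_weight=1)
y_true = torch.randn(4, 16000)  # one second of audio at 16 kHz
y_pred = torch.randn(4, 16000)
print(loss_fn(y_true, y_pred))  # one-element tensor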

View file

@@ -0,0 +1,56 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import yaml
from utils.templates import setup_dict
parser = argparse.ArgumentParser()
parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lace'], help='model name', default='lace')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
args = parser.parse_args()
setup = setup_dict[args.model]
# update dataset if given
if args.path2dataset is not None:
setup['dataset'] = args.path2dataset
name = args.name
if not name.endswith('.yml'):
name += '.yml'
if __name__ == '__main__':
with open(name, 'w') as f:
f.write(yaml.dump(setup))

View file

@@ -0,0 +1,36 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from .lace import LACE
model_dict = {
'lace': LACE
}

View file

@@ -0,0 +1,176 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
class LACE(NNSBase):
""" Linear-Adaptive Coding Enhancer """
FRAME_SIZE=80
def __init__(self,
num_features=47,
pitch_embedding_dim=64,
cond_dim=256,
pitch_max=257,
kernel_size=15,
preemph=0.85,
skip=91,
comb_gain_limit_db=-6,
global_gain_limits_db=[-6, 6],
conv_gain_limits_db=[-6, 6],
numbits_range=[50, 650],
numbits_embedding_dim=8,
hidden_feature_dim=64,
partial_lookahead=True,
norm_p=2):
super().__init__(skip=skip, preemph=preemph)
self.num_features = num_features
self.cond_dim = cond_dim
self.pitch_max = pitch_max
self.pitch_embedding_dim = pitch_embedding_dim
self.kernel_size = kernel_size
self.preemph = preemph
self.skip = skip
self.numbits_range = numbits_range
self.numbits_embedding_dim = numbits_embedding_dim
self.hidden_feature_dim = hidden_feature_dim
self.partial_lookahead = partial_lookahead
# pitch embedding
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
# numbits embedding
self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
# feature net
if partial_lookahead:
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
else:
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
# comb filters
left_pad = self.kernel_size // 2
right_pad = self.kernel_size - 1 - left_pad
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
# spectral shaping
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
def flop_count(self, rate=16000, verbose=False):
frame_rate = rate / self.FRAME_SIZE
# feature net
feature_net_flops = self.feature_net.flop_count(frame_rate)
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
af_flops = self.af1.flop_count(rate)
if verbose:
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
return feature_net_flops + comb_flops + af_flops
def forward(self, x, features, periods, numbits, debug=False):
periods = periods.squeeze(-1)
pitch_embedding = self.pitch_embedding(periods)
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
cf = self.feature_net(full_features)
y = self.cf1(x, cf, periods, debug=debug)
y = self.cf2(y, cf, periods, debug=debug)
y = self.af1(y, cf, debug=debug)
return y
def get_impulse_responses(self, features, periods, numbits):
""" generates impoulse responses on frame centers (input without batch dimension) """
num_frames = features.size(0)
batch_size = 32
max_len = 2 * (self.pitch_max + self.kernel_size) + 10
# spread out some pulses
x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
for b in range(batch_size):
x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1
# prepare input
x = torch.from_numpy(x).float().to(features.device)
features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)
# run network
with torch.no_grad():
periods = periods.squeeze(-1)
pitch_embedding = self.pitch_embedding(periods)
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
cf = self.feature_net(full_features)
y = self.cf1(x, cf, periods, debug=False)
y = self.cf2(y, cf, periods, debug=False)
y = self.af1(y, cf, debug=False)
# collect responses
y = y.detach().squeeze().cpu().numpy()
cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
num_responses = num_frames - cut_frames
responses = np.zeros((num_responses, max_len))
for i in range(num_responses):
b = i % batch_size
start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
stop = start + max_len
responses[i, :] = y[b, start:stop]
return responses
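
A hypothetical smoke test for the default configuration; the tensor shapes are inferred from SilkEnhancementSet and are not spelled out in this commit:

import torch

model = LACE()
B, F = 2, 100                               # batch size, number of 5 ms frames (multiple of 4)
x = torch.zeros(B, 1, F * LACE.FRAME_SIZE)  # coded signal, 80 samples per frame
features = torch.zeros(B, F, 47)
periods = torch.randint(0, 257, (B, F))     # pitch lags; 256 is the no-pitch value
numbits = torch.full((B, F, 2), 300.0)      # rate features inside the [50, 650] embedding range
y = model(x, features, periods, numbits)    # -> (B, 1, F * 80)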

View file

@@ -0,0 +1,69 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
class NNSBase(nn.Module):
def __init__(self, skip=91, preemph=0.85):
super().__init__()
self.skip = skip
self.preemph = preemph
def process(self, sig, features, periods, numbits, debug=False):
self.eval()
has_numbits = 'numbits' in self.forward.__code__.co_varnames
device = next(iter(self.parameters())).device
with torch.no_grad():
# run model
x = sig.view(1, 1, -1).to(device)
f = features.unsqueeze(0).to(device)
p = periods.unsqueeze(0).to(device)
n = numbits.unsqueeze(0).to(device)
if has_numbits:
y = self.forward(x, f, p, n, debug=debug).squeeze()
else:
y = self.forward(x, f, p, debug=debug).squeeze()
# deemphasis
if self.preemph > 0:
for i in range(len(y) - 1):
y[i + 1] += self.preemph * y[i]
# delay compensation
y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device)))
out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
return out
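
The per-sample loop above runs the IIR filter 1 / (1 - preemph * z^-1), undoing the preemphasis applied in the dataset; a vectorized equivalent, sketched with scipy (not used by this commit):

import numpy as np
from scipy.signal import lfilter

def deemphasis(y: np.ndarray, preemph: float = 0.85) -> np.ndarray:
    # y[n] = x[n] + preemph * y[n-1], the inverse of the preemphasis filter
    return lfilter([1.0], [1.0, -preemph], y).astype(np.float32)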

View file

@@ -0,0 +1,68 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import torch
from torch import nn
class ScaleEmbedding(nn.Module):
def __init__(self,
dim,
min_val,
max_val,
logscale=False):
super().__init__()
if min_val >= max_val:
raise ValueError('min_val must be smaller than max_val')
if min_val <= 0 and logscale:
raise ValueError('min_val must be positive when logscale is true')
self.dim = dim
self.logscale = logscale
self.min_val = min_val
self.max_val = max_val
if logscale:
self.min_val = m.log(self.min_val)
self.max_val = m.log(self.max_val)
self.offset = (self.min_val + self.max_val) / 2
self.scale_factors = nn.Parameter(
torch.arange(1, dim+1, dtype=torch.float32) * torch.pi / (self.max_val - self.min_val)
)
def forward(self, x):
if self.logscale: x = torch.log(x)
x = torch.clip(x, self.min_val, self.max_val) - self.offset
return torch.sin(x.unsqueeze(-1) * self.scale_factors - 0.5)
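
A quick shape check; the input mimics the two-channel numbits features with values inside the embedding range:

import torch

emb = ScaleEmbedding(8, 50, 650, logscale=True)
x = torch.full((1, 100, 2), 300.0)  # (batch, frames, channels)
print(emb(x).shape)                 # torch.Size([1, 100, 2, 8])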

View file

@@ -0,0 +1,86 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class SilkFeatureNet(nn.Module):
def __init__(self,
feature_dim=47,
num_channels=256,
lookahead=False):
super(SilkFeatureNet, self).__init__()
self.feature_dim = feature_dim
self.num_channels = num_channels
self.lookahead = lookahead
self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
def flop_count(self, rate=200):
count = 0
for conv in self.conv1, self.conv2:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
return count
def forward(self, features, state=None):
""" features shape: (batch_size, num_frames, feature_dim) """
batch_size = features.size(0)
if state is None:
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
features = features.permute(0, 2, 1)
if self.lookahead:
c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
c = torch.tanh(self.conv2(F.pad(c, [2, 2])))
else:
c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
c = torch.tanh(self.conv2(F.pad(c, [4, 0])))
c = c.permute(0, 2, 1)
c, _ = self.gru(c, state)
return c

View file

@@ -0,0 +1,90 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class SilkFeatureNetPL(nn.Module):
""" feature net with partial lookahead """
def __init__(self,
feature_dim=47,
num_channels=256,
hidden_feature_dim=64):
super(SilkFeatureNetPL, self).__init__()
self.feature_dim = feature_dim
self.num_channels = num_channels
self.hidden_feature_dim = hidden_feature_dim
self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)
self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)
self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
def flop_count(self, rate=200):
count = 0
for conv in self.conv1, self.conv2, self.tconv:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
return count
def forward(self, features, state=None):
""" features shape: (batch_size, num_frames, feature_dim) """
batch_size = features.size(0)
num_frames = features.size(1)
if state is None:
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
features = features.permute(0, 2, 1)
# dimensionality reduction
c = torch.tanh(self.conv1(features))
# frame accumulation
c = c.permute(0, 2, 1)
c = c.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1)
c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
# upsampling
c = self.tconv(c)
c = c.permute(0, 2, 1)
c, _ = self.gru(c, state)
return c
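
A shape trace for the default configuration; note that the frame accumulation step requires the number of input frames to be divisible by 4:

import torch

net = SilkFeatureNetPL(feature_dim=47, num_channels=256, hidden_feature_dim=64)
f = torch.randn(2, 100, 47)  # (batch_size, num_frames, feature_dim)
c = net(f)
print(c.shape)               # torch.Size([2, 100, 256]), back at the 5 ms frame rate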

View file

@@ -0,0 +1,96 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import torch
from scipy.io import wavfile
from models import model_dict
from utils.silk_features import load_inference_data
from utils import endoscopy
debug = False
if debug:
args = type('dummy', (object,),
{
'input' : 'testitems/all_0_orig.se',
'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
'output' : 'out.wav',
})()
else:
parser = argparse.ArgumentParser()
parser.add_argument('input', type=str, help='path to folder with features and signals')
parser.add_argument('checkpoint', type=str, help='checkpoint file')
parser.add_argument('output', type=str, help='output file')
parser.add_argument('--debug', action='store_true', help='enables debug output')
args = parser.parse_args()
torch.set_num_threads(2)
input_folder = args.input
checkpoint_file = args.checkpoint
output_file = args.output
if not output_file.endswith('.wav'):
output_file += '.wav'
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# check model
if 'name' not in checkpoint['setup']['model']:
    print('warning: did not find model name entry in setup, using pitchpostfilter by default')
model_name = 'pitchpostfilter'
else:
model_name = checkpoint['setup']['model']['name']
model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])
# generate model input
setup = checkpoint['setup']
signal, features, periods, numbits = load_inference_data(input_folder, **setup['data'])
if args.debug:
endoscopy.init()
output = model.process(signal, features, periods, numbits, debug=args.debug)
wavfile.write(output_file, 16000, output.cpu().numpy())
if args.debug:
endoscopy.close()

View file

@@ -0,0 +1,297 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import yaml
try:
import git
has_git = True
except ImportError:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
from scipy.io import wavfile
import pesq
from data import SilkEnhancementSet
from models import model_dict
from engine.engine import train_one_epoch, evaluate
from utils.silk_features import load_inference_data
from utils.misc import count_parameters
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
args = parser.parse_args()
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
setup = yaml.load(f.read(), yaml.FullLoader)
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'
# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default pitchpostfilter')
model_name = 'pitchpostfilter'
else:
model_name = setup['model']['name']
# prepare output folder
if os.path.exists(args.output):
print("warning: output folder exists")
reply = input('continue? (y/n): ')
while reply not in {'y', 'n'}:
reply = input('continue? (y/n): ')
if reply == 'n':
        os._exit(0)
else:
os.makedirs(args.output, exist_ok=True)
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup
if has_git:
working_dir = os.path.split(__file__)[0]
try:
repo = git.Repo(working_dir)
setup['repo'] = dict()
        commit_hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
is_dirty = repo.is_dirty()
if is_dirty:
print("warning: repo is dirty")
        setup['repo']['hash'] = commit_hash
setup['repo']['urls'] = urls
setup['repo']['dirty'] = is_dirty
    except Exception:
has_git = False
# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
yaml.dump(setup, f)
ref = None
if args.testdata is not None:
testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
inference_test = True
inference_folder = os.path.join(args.output, 'inference_test')
os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)
try:
ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except Exception:
pass
else:
inference_test = False
# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)
# load validation dataset if given
if 'validation_dataset' in setup:
validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
run_validation = True
else:
run_validation = False
# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
if args.initial_checkpoint is not None:
print(f"loading state dict from {args.initial_checkpoint}...")
chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
model.load_state_dict(chkpt['state_dict'])
# set compute device
if args.device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
# push model to device
model.to(device)
# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
# optimizer only sees trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)
# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
return torch.mean(loss)
def td_l2_norm(y_true, y_pred):
dims = list(range(1, len(y_true.shape)))
loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
return loss.mean()
def td_l1(y_true, y_pred, pow=0):
dims = list(range(1, len(y_true.shape)))
tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
return torch.mean(tmp)
def criterion(x, y):
return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
# model checkpoint
checkpoint = {
'setup' : setup,
'state_dict' : model.state_dict(),
'loss' : -1
}
if not args.no_redirect:
print(f"re-directing output to {os.path.join(args.output, output_file)}")
sys.stdout = open(os.path.join(args.output, output_file), "w")
print("summary:")
print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
if ref is not None:
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
print(f"initial MOS (PESQ): {initial_mos}")
best_loss = 1e9
for ep in range(1, epochs + 1):
print(f"training epoch {ep}...")
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
# save checkpoint
checkpoint['state_dict'] = model.state_dict()
checkpoint['loss'] = new_loss
if run_validation:
print("running validation...")
validation_loss = evaluate(model, criterion, validation_dataloader, device)
checkpoint['validation_loss'] = validation_loss
if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
best_loss = validation_loss
if inference_test:
print("running inference test...")
out = model.process(testsignal, features, periods, numbits).cpu().numpy()
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
if ref is not None:
mos = pesq.pesq(16000, ref, out, mode='wb')
print(f"MOS (PESQ): {mos}")
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))
print()
print('Done')

View file

@@ -0,0 +1,35 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
def _conv1d_flop_count(layer, rate):
    return 2 * (layer.in_channels + 1) * layer.out_channels * layer.kernel_size[0] * rate / layer.stride[0]

def _dense_flop_count(layer, rate):
    return 2 * (layer.in_features + 1) * layer.out_features * rate
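
As a worked example of this convention (two flops per multiply-accumulate, bias folded into the in_channels + 1 term): the 1x1 input convolution of SilkFeatureNetPL, evaluated at the 200 Hz frame rate.

from torch import nn

conv = nn.Conv1d(47, 64, 1)
print(_conv1d_flop_count(conv, rate=200))  # 2 * 48 * 64 * 1 * 200 = 1228800.0, about 1.23 MFLOPS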

View file

@@ -0,0 +1,234 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
hidden_size = gru.hidden_size
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
# reset gate
start, stop = 0 * hidden_size, 1 * hidden_size
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# update gate
start, stop = 1 * hidden_size, 2 * hidden_size
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# new gate
start, stop = 2 * hidden_size, 3 * hidden_size
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
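
For reference, the block slicing above follows PyTorch's GRU layout, in which weight_ih_l0 and weight_hh_l0 stack the reset, update and new-gate weights along the row dimension. The gates computed here are

    r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
    z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
    n = tanh(W_in x + b_in + r * (W_hn h + b_hn))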
def init(folder='endoscopy'):
""" sets up output folder for endoscopy data """
global _folder
_folder = folder
if not os.path.exists(folder):
os.makedirs(folder)
else:
print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
def write_data(key, data, fs):
""" appends data to previous data written under key """
global _state
# convert to numpy if torch.Tensor is given
if isinstance(data, torch.Tensor):
data = data.detach().numpy()
    if key not in _state:
_state[key] = {
'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
'fs' : fs,
'dim' : tuple(data.shape),
'dtype' : str(data.dtype)
}
with open(os.path.join(_folder, key + '.yml'), 'w') as f:
f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
else:
if _state[key]['fs'] != fs:
raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
if _state[key]['dtype'] != str(data.dtype):
raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
if _state[key]['dim'] != tuple(data.shape):
raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
_state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
""" clean up """
for key in _state.keys():
_state[key]['fid'].close()
def read_data(folder='endoscopy'):
""" retrieves written data as numpy arrays """
keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
return_dict = dict()
for key in keys:
with open(os.path.join(folder, key + '.yml'), 'r') as f:
value = yaml.load(f.read(), yaml.FullLoader)
with open(os.path.join(folder, key + '.bin'), 'rb') as f:
data = np.frombuffer(f.read(), dtype=value['dtype'])
value['data'] = data.reshape((-1,) + value['dim'])
return_dict[key] = value
return return_dict
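
A minimal write/read round trip; the key name, folder and sampling rate here are arbitrary:

import torch

init('endoscopy_demo')
write_data('gain', torch.ones(4), fs=100)
close()
print(read_data('endoscopy_demo')['gain']['data'].shape)  # (1, 4)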
def get_best_reshape(shape, target_ratio=1):
""" calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
if len(shape) > 1:
pixel_count = 1
for s in shape:
pixel_count *= s
else:
pixel_count = shape[0]
if pixel_count == 1:
return (1,)
num_columns = int((pixel_count / target_ratio)**.5)
while (pixel_count % num_columns):
num_columns -= 1
num_rows = pixel_count // num_columns
return (num_rows, num_columns)
def get_type_and_shape(shape):
# can happen if data is one dimensional
if len(shape) == 0:
shape = (1,)
# calculate pixel count
if len(shape) > 1:
pixel_count = 1
for s in shape:
pixel_count *= s
else:
pixel_count = shape[0]
if pixel_count == 1:
return 'plot', (1, )
# stay with shape if already 2-dimensional
if len(shape) == 2:
if (shape[0] != pixel_count) or (shape[1] != pixel_count):
return 'image', shape
return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
# determine plot setup
num_keys = len(data.keys())
num_rows = int((num_keys * 3/4) ** .5)
num_cols = (num_keys + num_rows - 1) // num_rows
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
fig.set_size_inches(num_cols * 5, num_rows * 5)
display = dict()
fs_max = max([val['fs'] for val in data.values()])
num_samples = max([val['data'].shape[0] for val in data.values()])
keys = sorted(data.keys())
# inspect data
for i, key in enumerate(keys):
axs[i // num_cols, i % num_cols].title.set_text(key)
display[key] = dict()
display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
display[key]['down_factor'] = data[key]['fs'] / fs_max
start_index = max(start_index, half_signal_window_length)
while stop_index < 0:
stop_index += num_samples
stop_index = min(stop_index, num_samples - half_signal_window_length)
# actual plotting
frames = []
for index in range(start_index, stop_index):
ims = []
for i, key in enumerate(keys):
feature_index = int(round(index * display[key]['down_factor']))
if display[key]['type'] == 'plot':
ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
elif display[key]['type'] == 'image':
ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
frames.append(ims)
ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
if not filename.endswith('.mp4'):
filename += '.mp4'
ani.save(filename)
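# A minimal usage sketch, assuming write_data registers a new key on first
# call (the registration logic sits earlier in this file):
#
#   import numpy as np
#   for _ in range(200):
#       write_data('demo_gain', np.random.rand(4).astype(np.float32), 100)
#   close()
#   data = read_data('endoscopy')
#   make_animation(data, 'debug_run')   # '.mp4' is appended automatically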

View file

@ -0,0 +1,236 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.endoscopy import write_data
class LimitedAdaptiveComb1d(nn.Module):
COUNTER = 1
def __init__(self,
kernel_size,
feature_dim,
frame_size=160,
overlap_size=40,
use_bias=True,
padding=None,
max_lag=256,
name=None,
gain_limit_db=10,
global_gain_limits_db=[-6, 6],
norm_p=2):
"""
Parameters:
-----------
kernel_size : int
    kernel size of the adaptive comb filter
feature_dim : int
    dimension of features from which kernels, biases and gains are computed
frame_size : int, optional
frame size, defaults to 160
overlap_size : int, optional
overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
use_bias : bool, optional
if true, biases will be added to output channels. Defaults to True
padding : List[int, int], optional
left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2]
max_lag : int, optional
maximal pitch lag, defaults to 256
gain_limit_db : float, optional
    maximal gain of the comb-filter branch in dB, defaults to 10
global_gain_limits_db : List[float, float], optional
    lower and upper limit for the global filter gain in dB, defaults to [-6, 6]
name: str or None, optional
specifies a name attribute for the module. If None the name is auto generated as comb_1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
"""
super(LimitedAdaptiveComb1d, self).__init__()
self.in_channels = 1
self.out_channels = 1
self.feature_dim = feature_dim
self.kernel_size = kernel_size
self.frame_size = frame_size
self.overlap_size = overlap_size
self.use_bias = use_bias
self.max_lag = max_lag
self.limit_db = gain_limit_db
self.norm_p = norm_p
if name is None:
self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
LimitedAdaptiveComb1d.COUNTER += 1
else:
self.name = name
# network for generating convolution weights
self.conv_kernel = nn.Linear(feature_dim, kernel_size)
if self.use_bias:
self.conv_bias = nn.Linear(feature_dim,1)
# comb filter gain
self.filter_gain = nn.Linear(feature_dim, 1)
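# 0.11512925464970229 = log(10) / 20, i.e. dB-to-natural-log conversion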
self.log_gain_limit = gain_limit_db * 0.11512925464970229
with torch.no_grad():
self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
self.global_filter_gain = nn.Linear(feature_dim, 1)
log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
self.filter_gain_a = (log_max - log_min) / 2
self.filter_gain_b = (log_max + log_min) / 2
if padding is None:
self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
else:
self.padding = padding
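# raised-cosine cross-fade window; added to its time-reversed copy it sums
# to one, so the overlap-add in forward() is amplitude preserving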
self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
def forward(self, x, features, lags, debug=False):
""" adaptive 1d convolution
Parameters:
-----------
x : torch.tensor
input signal of shape (batch_size, in_channels, num_samples)
features : torch.tensor
frame-wise features of shape (batch_size, num_frames, feature_dim)
lags: torch.LongTensor
frame-wise lags for comb-filtering
"""
batch_size = x.size(0)
num_frames = features.size(1)
num_samples = x.size(2)
frame_size = self.frame_size
overlap_size = self.overlap_size
kernel_size = self.kernel_size
win1 = torch.flip(self.overlap_win, [0])
win2 = self.overlap_win
if num_samples // self.frame_size != num_frames:
raise ValueError('non matching sizes in LimitedAdaptiveComb1d.forward')
conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
if self.use_bias:
conv_biases = self.conv_bias(features).permute(0, 2, 1)
conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
# calculate gains
global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
if debug and batch_size == 1:
key = self.name + "_gains"
write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
key = self.name + "_kernels"
write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
key = self.name + "_lags"
write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
key = self.name + "_global_conv_gains"
write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
# frame-wise convolution with overlap-add
output_frames = []
overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
x = F.pad(x, self.padding)
x = F.pad(x, [self.max_lag, self.overlap_size])
idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
idx = torch.repeat_interleave(idx, batch_size, 0)
idx = torch.repeat_interleave(idx, self.in_channels, 1)
for i in range(num_frames):
cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
if self.use_bias:
new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
offset = self.max_lag + self.padding[0]
new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])
# overlapping part
output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
# non-overlapping part
output_frames.append(new_chunk[:, :, overlap_size : frame_size])
# mem for next frame
overlap_mem = new_chunk[:, :, frame_size :]
# concatenate chunks
output = torch.cat(output_frames, dim=-1)
return output
def flop_count(self, rate):
frame_rate = rate / self.frame_size
overlap = self.overlap_size
overhead = overlap / self.frame_size
count = 0
# kernel computation and filtering
count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
# bias computation
if self.use_bias:
count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
# global gain computation
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
# windowing
count += overlap * frame_rate * 3 * self.out_channels
return count
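# A minimal usage sketch (random data; dimensions are illustrative, not
# prescriptive):
#
#   comb = LimitedAdaptiveComb1d(kernel_size=15, feature_dim=93, frame_size=80)
#   x        = torch.randn(1, 1, 800)            # (batch, channels, samples)
#   features = torch.randn(1, 10, 93)            # (batch, frames, feature_dim)
#   lags     = torch.randint(32, 256, (1, 10))   # one pitch lag per frame, below max_lag
#   y        = comb(x, features, lags)           # same shape as x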

View file

@ -0,0 +1,222 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.endoscopy import write_data
class LimitedAdaptiveConv1d(nn.Module):
COUNTER = 1
def __init__(self,
in_channels,
out_channels,
kernel_size,
feature_dim,
frame_size=160,
overlap_size=40,
use_bias=True,
padding=None,
name=None,
gain_limits_db=[-6, 6],
shape_gain_db=0,
norm_p=2):
"""
Parameters:
-----------
in_channels : int
number of input channels
out_channels : int
number of output channels
feature_dim : int
dimension of features from which kernels, biases and gains are computed
frame_size : int
frame size
overlap_size : int
overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame
use_bias : bool
if true, biases will be added to output channels
padding : List[int, int], optional
    left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
gain_limits_db : List[float, float], optional
    lower and upper limit for the filter gain in dB, defaults to [-6, 6]
"""
super(LimitedAdaptiveConv1d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.feature_dim = feature_dim
self.kernel_size = kernel_size
self.frame_size = frame_size
self.overlap_size = overlap_size
self.use_bias = use_bias
self.gain_limits_db = gain_limits_db
self.shape_gain_db = shape_gain_db
self.norm_p = norm_p
if name is None:
self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
LimitedAdaptiveConv1d.COUNTER += 1
else:
self.name = name
# network for generating convolution weights
self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
if self.use_bias:
self.conv_bias = nn.Linear(feature_dim, out_channels)
self.shape_gain = min(1, 10**(shape_gain_db / 20))
self.filter_gain = nn.Linear(feature_dim, out_channels)
log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
self.filter_gain_a = (log_max - log_min) / 2
self.filter_gain_b = (log_max + log_min) / 2
if padding is None:
self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
else:
self.padding = padding
self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
def flop_count(self, rate):
frame_rate = rate / self.frame_size
overlap = self.overlap_size
overhead = overlap / self.frame_size
count = 0
# kernel computation and filtering
count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
# bias computation
if self.use_bias:
count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
# gain computation
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
# windowing
count += 3 * overlap * frame_rate * self.out_channels
return count
def forward(self, x, features, debug=False):
""" adaptive 1d convolution
Parameters:
-----------
x : torch.tensor
input signal of shape (batch_size, in_channels, num_samples)
features : torch.tensor
frame-wise features of shape (batch_size, num_frames, feature_dim)
"""
batch_size = x.size(0)
num_frames = features.size(1)
num_samples = x.size(2)
frame_size = self.frame_size
overlap_size = self.overlap_size
kernel_size = self.kernel_size
win1 = torch.flip(self.overlap_win, [0])
win2 = self.overlap_win
if num_samples // self.frame_size != num_frames:
raise ValueError('non matching sizes in LimitedAdaptiveConv1d.forward')
conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
# normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))
# limit shape
id_kernels = torch.zeros_like(conv_kernels)
id_kernels[..., self.padding[1]] = 1
conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels
if self.use_bias:
conv_biases = self.conv_bias(features).permute(0, 2, 1)
# calculate gains
conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
if debug and batch_size == 1:
key = self.name + "_gains"
write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
key = self.name + "_kernels"
write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
# frame-wise convolution with overlap-add
output_frames = []
overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
x = F.pad(x, self.padding)
x = F.pad(x, [0, self.overlap_size])
for i in range(num_frames):
xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1))
new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
if self.use_bias:
new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
new_chunk = new_chunk * conv_gains[:, :, i : i + 1]
# overlapping part
output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
# non-overlapping part
output_frames.append(new_chunk[:, :, overlap_size : frame_size])
# mem for next frame
overlap_mem = new_chunk[:, :, frame_size :]
# concatenate chunks
output = torch.cat(output_frames, dim=-1)
return output
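# A minimal usage sketch (random data; dimensions are illustrative):
#
#   conv = LimitedAdaptiveConv1d(2, 1, 15, 93, frame_size=80)
#   x        = torch.randn(1, 2, 800)     # (batch, in_channels, samples)
#   features = torch.randn(1, 10, 93)     # (batch, frames, feature_dim)
#   y        = conv(x, features)          # (1, 1, 800)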

View file

@ -0,0 +1,84 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
class PitchAutoCorrelator(nn.Module):
def __init__(self,
frame_size=80,
pitch_min=32,
pitch_max=300,
radius=2):
super().__init__()
self.frame_size = frame_size
self.pitch_min = pitch_min
self.pitch_max = pitch_max
self.radius = radius
def forward(self, x, periods):
# x of shape (batch_size, channels, num_samples)
# periods of shape (batch_size, num_frames)
num_frames = periods.size(1)
batch_size = periods.size(0)
num_samples = self.frame_size * num_frames
channels = x.size(1)
assert num_samples == x.size(-1)
delta = torch.arange(-self.radius, self.radius + 1, device=x.device)
idx = torch.arange(self.frame_size * num_frames, device=x.device)
p_up = torch.repeat_interleave(periods, self.frame_size, 1)
lookup = idx + self.pitch_max - p_up
lookup = lookup.unsqueeze(-1) + delta
lookup = lookup.unsqueeze(1)
# padding
x_pad = F.pad(x, [self.pitch_max, 0])
x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)
# framing
x_select = torch.gather(x_ext, 2, lookup)
x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)
# calculate auto-correlation
dotp = torch.sum(x_frames * lag_frames, dim=-2)
frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)
acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)
return acorr
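# A minimal usage sketch (random data; periods must stay at least `radius`
# below pitch_max so the lag window remains inside the padded signal):
#
#   acorr = PitchAutoCorrelator(frame_size=80, radius=2)
#   x = torch.randn(1, 1, 800)                  # (batch, channels, samples)
#   periods = torch.randint(32, 298, (1, 10))   # one period per 80-sample frame
#   c = acorr(x, periods)                       # (1, 1, 10, 2 * radius + 1)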

View file

@ -0,0 +1,42 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
def count_parameters(model, verbose=False):
total = 0
for name, p in model.named_parameters():
count = torch.ones_like(p).sum().item()
if verbose:
print(f"{name}: {count} parameters")
total += count
return total
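# Example:
#
#   model = torch.nn.Linear(10, 2)
#   count_parameters(model, verbose=True)   # prints per-parameter counts, returns 22.0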

View file

@ -0,0 +1,121 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
def hangover(lags, num_frames=10):
lags = lags.copy()
count = 0
last_lag = 0
for i in range(len(lags)):
lag = lags[i]
if lag == 0:
if count < num_frames:
lags[i] = last_lag
count += 1
else:
    count = 0
    last_lag = lag
return lags
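# Worked example: hangover(np.array([50, 0, 0, 0]), num_frames=2) returns
# [50, 50, 50, 0] -- the last voiced lag is held for at most num_frames
# unvoiced frames.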
def smooth_pitch_lags(lags, d=2):
assert d < 4
num_silk_frames = len(lags) // 4
smoothed_lags = lags.copy()
tmp = np.arange(1, d+1)
kernel = np.concatenate((tmp, [d+1], tmp[::-1]), dtype=np.float32)
kernel = kernel / np.sum(kernel)
last = lags[0:d][::-1]
for i in range(num_silk_frames):
frame = lags[i * 4: (i+1) * 4]
if np.max(np.abs(frame)) == 0:
last = frame[4-d:]
continue
if i == num_silk_frames - 1:
next = frame[4-d:][::-1]
else:
next = lags[(i+1) * 4 : (i+1) * 4 + d]
if np.max(np.abs(next)) == 0:
next = frame[4-d:][::-1]
if np.max(np.abs(last)) == 0:
last = frame[0:d][::-1]
smoothed_frame = np.convolve(np.concatenate((last, frame, next), dtype=np.float32), kernel, mode='valid')
smoothed_lags[i * 4: (i+1) * 4] = np.round(smoothed_frame)
last = frame[4-d:]
return smoothed_lags
def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
eps = 1e-9
lag_multiplier = 2 if add_double_lag_acorr else 1
if history is None:
history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)
offset = len(history)
assert offset >= max_lag + radius
assert len(x) % frame_size == 0
num_frames = len(x) // frame_size
lags = lags.copy()
x_ext = np.concatenate((history, x), dtype=x.dtype)
d = radius
num_acorrs = 2 * d + 1
acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)
for idx in range(num_frames):
lag = lags[idx].item()
frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]
for k in range(lag_multiplier):
lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
for j in range(num_acorrs):
past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)
return acorrs, lags
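# A minimal usage sketch (random data):
#
#   x = np.random.randn(800).astype(np.float32)   # 10 frames of 80 samples
#   lags = np.full(10, 120, dtype=np.int64)
#   acorrs, lags = calculate_acorr_window(x, 80, lags, radius=2)
#   # acorrs: (10, 2 * radius + 1); add_double_lag_acorr=True appends the
#   # correlations at twice the lag, doubling the last dimension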

View file

@ -0,0 +1,151 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import numpy as np
import torch
import scipy
from utils.pitch import hangover, calculate_acorr_window
from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
def spec_from_lpc(a, n_fft=128, eps=1e-9):
order = a.shape[-1]
assert order + 1 < n_fft
x = np.zeros((*a.shape[:-1], n_fft ))
x[..., 0] = 1
x[..., 1:1 + order] = -a
X = np.fft.fft(x, axis=-1)
X = np.abs(X[..., :n_fft//2 + 1]) ** 2
S = 1 / (X + eps)
return S
def silk_feature_factory(no_pitch_value=256,
acorr_radius=2,
pitch_hangover=8,
num_bands_clean_spec=64,
num_bands_noisy_spec=18,
noisy_spec_scale='opus',
noisy_apply_dct=True,
add_offset=False,
add_double_lag_acorr=False
):
w = scipy.signal.windows.cosine(320)
fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)
def create_features(noisy, noisy_history, lpcs, gains, ltps, periods, offsets):
periods = periods.copy()
if pitch_hangover > 0:
periods = hangover(periods, num_frames=pitch_hangover)
periods[periods == 0] = no_pitch_value
clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)
if noisy_apply_dct:
noisy_cepstrum = np.repeat(
cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
else:
noisy_cepstrum = np.repeat(
log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
log_gains = np.log(gains + 1e-9).reshape(-1, 1)
acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)
if add_offset:
features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains, offsets.reshape(-1, 1)), axis=-1, dtype=np.float32)
else:
features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)
return features, periods.astype(np.int64)
return create_features
def load_inference_data(path,
no_pitch_value=256,
skip=92,
preemph=0.85,
acorr_radius=2,
pitch_hangover=8,
num_bands_clean_spec=64,
num_bands_noisy_spec=18,
noisy_spec_scale='opus',
noisy_apply_dct=True,
add_offset=False,
add_double_lag_acorr=False,
**kwargs):
print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")
lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)
offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
# load signal, add back delay and pre-emphasize
signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)
create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset, add_double_lag_acorr)
num_frames = min((len(signal) // 320) * 4, len(lpcs))
signal = signal[: num_frames * 80]
lpcs = lpcs[: num_frames]
ltps = ltps[: num_frames]
gains = gains[: num_frames]
periods = periods[: num_frames]
num_bits = num_bits[: num_frames // 4]
num_bits_smooth = num_bits_smooth[: num_frames // 4]
offsets = offsets[: num_frames]
numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)
features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods, offsets)
if preemph > 0:
signal[1:] -= preemph * signal[:-1]
return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)
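# A hedged usage sketch ('path/to/inference_dir' is a hypothetical directory
# containing the feature files listed above):
#
#   signal, features, periods, numbits = load_inference_data('path/to/inference_dir')
#   # signal: (num_frames * 80,), features: (num_frames, feature_dim),
#   # periods: (num_frames,), numbits: (num_frames, 2)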

View file

@ -0,0 +1,194 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import numpy as np
import scipy
def erb(f):
return 24.7 * (4.37 * f + 1)
def inv_erb(e):
return (e / 24.7 - 1) / 4.37
def bark(f):
return 6 * m.asinh(f/600)
def inv_bark(b):
return 600 * m.sinh(b / 6)
scale_dict = {
'bark': [bark, inv_bark],
'erb': [erb, inv_erb]
}
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
f0 = 0
num_bins = n_fft // 2 + 1
f1 = fs / n_fft * (num_bins - 1)
fstep = fs / n_fft
if scale == 'opus':
bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
fac = 1000 * n_fft / fs / 5
if num_bands != 18:
print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
num_bands = 18
center_bins = np.array([fac * bin for bin in bins_5ms])
else:
to_scale, from_scale = scale_dict[scale]
s0 = to_scale(f0)
s1 = to_scale(f1)
center_freqs = np.array([f0] + [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)] + [f1])
center_bins = (center_freqs - f0) / fstep
if round_center_bins:
center_bins = np.round(center_bins)
filter_bank = np.zeros((num_bands, num_bins))
band = 0
for bin in range(num_bins):
# update band index
if bin > center_bins[band + 1]:
band += 1
# calculate filter coefficients
frac = (center_bins[band + 1] - bin) / (center_bins[band + 1] - center_bins[band])
filter_bank[band][bin] = frac
filter_bank[band + 1][bin] = 1 - frac
if return_upper:
extend = n_fft - num_bins
filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)
if normalize:
filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)
return filter_bank
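# Example: the 18-band Opus filter bank over a 320-point FFT at 16 kHz:
#
#   fb = create_filter_bank(18, n_fft=320, fs=16000, scale='opus',
#                           round_center_bins=True, normalize=True)
#   fb.shape    # (18, 161): one row per band over the non-negative bins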
def compressed_log_spec(pspec):
lpspec = np.zeros_like(pspec)
num_bands = pspec.shape[-1]
log_max = -2
follow = -2
for i in range(num_bands):
tmp = np.log10(pspec[i] + 1e-9)
tmp = max(log_max, max(follow - 2.5, tmp))
lpspec[i] = tmp
log_max = max(log_max, tmp)
follow = max(follow - 2.5, tmp)
return lpspec
def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
""" calculates cepstrum from SILK lpcs """
order = a.shape[-1]
assert order + 1 < n_fft
a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)
x = np.zeros((*a.shape[:-1], n_fft ))
x[..., 0] = 1
x[..., 1:1 + order] = -a
X = np.fft.fft(x, axis=-1)
X = np.abs(X[..., :n_fft//2 + 1]) ** power
S = 1 / (X + eps)
if fb is None:
Sf = S
else:
Sf = np.matmul(S, fb.T)
if compress:
Sf = np.apply_along_axis(compressed_log_spec, -1, Sf)
else:
Sf = np.log(Sf + eps)
return Sf
def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
""" calculates cepstrum from SILK lpcs """
Sf = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
cepstrum = scipy.fftpack.dct(Sf, 2, norm='ortho')
return cepstrum
def log_spectrum(x, frame_size, fb=None, window=None, power=1):
""" calculate cepstrum on 50% overlapping frames """
assert (2 * len(x)) % frame_size == 0
assert frame_size % 2 == 0
n = len(x)
num_even = n // frame_size
num_odd = (n - frame_size // 2) // frame_size
num_bins = frame_size // 2 + 1
x_even = x[:num_even * frame_size].reshape(-1, frame_size)
x_odd = x[frame_size//2 : frame_size//2 + frame_size * num_odd].reshape(-1, frame_size)
x_unfold = np.empty((x_even.size + x_odd.size), dtype=x.dtype).reshape((-1, frame_size))
x_unfold[::2, :] = x_even
x_unfold[1::2, :] = x_odd
if window is not None:
x_unfold *= window.reshape(1, -1)
X = np.abs(np.fft.fft(x_unfold, n=frame_size, axis=-1))[:, :num_bins] ** power
if fb is not None:
X = np.matmul(X, fb.T)
return np.log(X + 1e-9)
def cepstrum(x, frame_size, fb=None, window=None):
""" calculate cepstrum on 50% overlapping frames """
X = log_spectrum(x, frame_size, fb, window)
cepstrum = scipy.fftpack.dct(X, 2, norm='ortho')
return cepstrum
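# A minimal usage sketch (random input; window and filter bank match the
# settings used in silk_features.py):
#
#   import scipy.signal
#   w = scipy.signal.windows.cosine(320)
#   fb = create_filter_bank(18, n_fft=320, scale='opus', round_center_bins=True, normalize=True)
#   x = np.random.randn(1600).astype(np.float32)   # length must be a multiple of frame_size // 2
#   c = cepstrum(x, 320, fb, w)                    # (num_frames, 18) on 50% overlapping frames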

View file

@ -0,0 +1,92 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
setup_dict = dict()
lace_setup = {
'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
'model': {
'name': 'lace',
'args': [],
'kwargs': {
'comb_gain_limit_db': 10,
'cond_dim': 128,
'conv_gain_limits_db': [-12, 12],
'global_gain_limits_db': [-6, 6],
'hidden_feature_dim': 96,
'kernel_size': 15,
'num_features': 93,
'numbits_embedding_dim': 8,
'numbits_range': [50, 650],
'partial_lookahead': True,
'pitch_embedding_dim': 64,
'pitch_max': 300,
'preemph': 0.85,
'skip': 91
}
},
'data': {
'frames_per_sample': 100,
'no_pitch_value': 7,
'preemph': 0.85,
'skip': 91,
'pitch_hangover': 8,
'acorr_radius': 2,
'num_bands_clean_spec': 64,
'num_bands_noisy_spec': 18,
'noisy_spec_scale': 'opus',
},
'training': {
'batch_size': 256,
'lr': 5.e-4,
'lr_decay_factor': 2.5e-5,
'epochs': 50,
'frames_per_sample': 50,
'loss': {
'w_l1': 0,
'w_lm': 0,
'w_logmel': 0,
'w_sc': 0,
'w_wsc': 0,
'w_xcorr': 0,
'w_sxcorr': 1,
'w_l2': 10,
'w_slm': 2
}
}
}
setup_dict = {
'lace': lace_setup,
}
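# Hypothetical usage -- the training entry point presumably selects a setup by name:
#
#   setup = setup_dict['lace']
#   model_kwargs = setup['model']['kwargs']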