mirror of
https://github.com/xiph/opus.git
synced 2025-05-29 14:49:14 +00:00
Opus ng lace
This commit is contained in:
parent
178672ed18
commit
105e1d83fa
24 changed files with 2937 additions and 0 deletions
4
dnn/torch/osce/README.md
Normal file
4
dnn/torch/osce/README.md
Normal file
|
@ -0,0 +1,4 @@
|
|||
# Opus Speech Coding Enhancement
|
||||
|
||||
This folder hosts models for enhancing SILK. See related Opus repo https://gitlab.xiph.org/xiph/opus/-/tree/exp-neural-silk-enhancement
|
||||
for feature generation.
|
30
dnn/torch/osce/data/__init__.py
Normal file
30
dnn/torch/osce/data/__init__.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from .silk_enhancement_set import SilkEnhancementSet
|
140
dnn/torch/osce/data/silk_enhancement_set.py
Normal file
140
dnn/torch/osce/data/silk_enhancement_set.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
|
||||
from utils.silk_features import silk_feature_factory
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
|
||||
|
||||
class SilkEnhancementSet(Dataset):
|
||||
def __init__(self,
|
||||
path,
|
||||
frames_per_sample=100,
|
||||
no_pitch_value=256,
|
||||
preemph=0.85,
|
||||
skip=91,
|
||||
acorr_radius=2,
|
||||
pitch_hangover=8,
|
||||
num_bands_clean_spec=64,
|
||||
num_bands_noisy_spec=18,
|
||||
noisy_spec_scale='opus',
|
||||
noisy_apply_dct=True,
|
||||
add_offset=False,
|
||||
add_double_lag_acorr=False
|
||||
):
|
||||
|
||||
assert frames_per_sample % 4 == 0
|
||||
|
||||
self.frame_size = 80
|
||||
self.frames_per_sample = frames_per_sample
|
||||
self.no_pitch_value = no_pitch_value
|
||||
self.preemph = preemph
|
||||
self.skip = skip
|
||||
self.acorr_radius = acorr_radius
|
||||
self.pitch_hangover = pitch_hangover
|
||||
self.num_bands_clean_spec = num_bands_clean_spec
|
||||
self.num_bands_noisy_spec = num_bands_noisy_spec
|
||||
self.noisy_spec_scale = noisy_spec_scale
|
||||
self.add_double_lag_acorr = add_double_lag_acorr
|
||||
|
||||
self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
|
||||
self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
|
||||
self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
|
||||
self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
|
||||
self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
|
||||
self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
|
||||
self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
|
||||
|
||||
self.clean_signal = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
|
||||
self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)
|
||||
|
||||
self.create_features = silk_feature_factory(no_pitch_value,
|
||||
acorr_radius,
|
||||
pitch_hangover,
|
||||
num_bands_clean_spec,
|
||||
num_bands_noisy_spec,
|
||||
noisy_spec_scale,
|
||||
noisy_apply_dct,
|
||||
add_offset,
|
||||
add_double_lag_acorr)
|
||||
|
||||
self.history_len = 700 if add_double_lag_acorr else 350
|
||||
# discard some frames to have enough signal history
|
||||
self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)
|
||||
|
||||
num_frames = self.clean_signal.shape[0] // 80 - self.skip_frames
|
||||
|
||||
self.len = num_frames // frames_per_sample
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
frame_start = self.frames_per_sample * index + self.skip_frames
|
||||
frame_stop = frame_start + self.frames_per_sample
|
||||
|
||||
signal_start = frame_start * self.frame_size - self.skip
|
||||
signal_stop = frame_stop * self.frame_size - self.skip
|
||||
|
||||
clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
|
||||
coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15
|
||||
|
||||
coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15
|
||||
|
||||
features, periods = self.create_features(
|
||||
coded_signal,
|
||||
coded_signal_history,
|
||||
self.lpcs[frame_start : frame_stop],
|
||||
self.gains[frame_start : frame_stop],
|
||||
self.ltps[frame_start : frame_stop],
|
||||
self.periods[frame_start : frame_stop],
|
||||
self.offsets[frame_start : frame_stop]
|
||||
)
|
||||
|
||||
if self.preemph > 0:
|
||||
clean_signal[1:] -= self.preemph * clean_signal[: -1]
|
||||
coded_signal[1:] -= self.preemph * coded_signal[: -1]
|
||||
|
||||
num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
|
||||
num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
|
||||
|
||||
numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)
|
||||
|
||||
return {
|
||||
'features' : features,
|
||||
'periods' : periods.astype(np.int64),
|
||||
'target' : clean_signal.astype(np.float32),
|
||||
'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
|
||||
'numbits' : numbits.astype(np.float32)
|
||||
}
|
101
dnn/torch/osce/engine/engine.py
Normal file
101
dnn/torch/osce/engine/engine.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
|
||||
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
|
||||
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
running_loss = 0
|
||||
previous_running_loss = 0
|
||||
|
||||
|
||||
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
|
||||
|
||||
for i, batch in enumerate(tepoch):
|
||||
|
||||
# set gradients to zero
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
# push batch to device
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device)
|
||||
|
||||
target = batch['target']
|
||||
|
||||
# calculate model output
|
||||
output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
|
||||
|
||||
# calculate loss
|
||||
if isinstance(output, list):
|
||||
loss = torch.zeros(1, device=device)
|
||||
for y in output:
|
||||
loss = loss + criterion(target, y.squeeze(1))
|
||||
loss = loss / len(output)
|
||||
else:
|
||||
loss = criterion(target, output.squeeze(1))
|
||||
|
||||
# calculate gradients
|
||||
loss.backward()
|
||||
|
||||
# update weights
|
||||
optimizer.step()
|
||||
|
||||
# update learning rate
|
||||
scheduler.step()
|
||||
|
||||
# update running loss
|
||||
running_loss += float(loss.cpu())
|
||||
|
||||
# update status bar
|
||||
if i % log_interval == 0:
|
||||
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
|
||||
previous_running_loss = running_loss
|
||||
|
||||
|
||||
running_loss /= len(dataloader)
|
||||
|
||||
return running_loss
|
||||
|
||||
def evaluate(model, criterion, dataloader, device, log_interval=10):
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
running_loss = 0
|
||||
previous_running_loss = 0
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
|
||||
|
||||
for i, batch in enumerate(tepoch):
|
||||
|
||||
|
||||
|
||||
# push batch to device
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device)
|
||||
|
||||
target = batch['target']
|
||||
|
||||
# calculate model output
|
||||
output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
|
||||
|
||||
# calculate loss
|
||||
loss = criterion(target, output.squeeze(1))
|
||||
|
||||
# update running loss
|
||||
running_loss += float(loss.cpu())
|
||||
|
||||
# update status bar
|
||||
if i % log_interval == 0:
|
||||
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
|
||||
previous_running_loss = running_loss
|
||||
|
||||
|
||||
running_loss /= len(dataloader)
|
||||
|
||||
return running_loss
|
277
dnn/torch/osce/losses/stft_loss.py
Normal file
277
dnn/torch/osce/losses/stft_loss.py
Normal file
|
@ -0,0 +1,277 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_window(win_name, win_length, *args, **kwargs):
|
||||
window_dict = {
|
||||
'bartlett_window' : torch.bartlett_window,
|
||||
'blackman_window' : torch.blackman_window,
|
||||
'hamming_window' : torch.hamming_window,
|
||||
'hann_window' : torch.hann_window,
|
||||
'kaiser_window' : torch.kaiser_window
|
||||
}
|
||||
|
||||
if not win_name in window_dict:
|
||||
raise ValueError()
|
||||
|
||||
return window_dict[win_name](win_length, *args, **kwargs)
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
"""
|
||||
|
||||
win = get_window(window, win_length).to(x.device)
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)
|
||||
|
||||
|
||||
return torch.clamp(torch.abs(x_stft), min=1e-7)
|
||||
|
||||
def spectral_convergence_loss(Y_true, Y_pred):
|
||||
dims=list(range(1, len(Y_pred.shape)))
|
||||
return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6))
|
||||
|
||||
|
||||
def log_magnitude_loss(Y_true, Y_pred):
|
||||
Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15)
|
||||
Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15)
|
||||
|
||||
return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs))
|
||||
|
||||
def spectral_xcorr_loss(Y_true, Y_pred):
|
||||
Y_true = Y_true.abs()
|
||||
Y_pred = Y_pred.abs()
|
||||
dims=list(range(1, len(Y_pred.shape)))
|
||||
xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9)
|
||||
|
||||
return 1 - xcorr.mean()
|
||||
|
||||
|
||||
|
||||
class MRLogMelLoss(nn.Module):
|
||||
def __init__(self,
|
||||
fft_sizes=[512, 256, 128, 64],
|
||||
overlap=0.5,
|
||||
fs=16000,
|
||||
n_mels=18
|
||||
):
|
||||
|
||||
self.fft_sizes = fft_sizes
|
||||
self.overlap = overlap
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.mel_specs = []
|
||||
for fft_size in fft_sizes:
|
||||
hop_size = int(round(fft_size * (1 - self.overlap)))
|
||||
|
||||
n_mels = self.n_mels
|
||||
if fft_size < 128:
|
||||
n_mels //= 2
|
||||
|
||||
self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
|
||||
|
||||
for i, mel_spec in enumerate(self.mel_specs):
|
||||
self.add_module(f'mel_spec_{i+1}', mel_spec)
|
||||
|
||||
def forward(self, y_true, y_pred):
|
||||
|
||||
loss = torch.zeros(1, device=y_true.device)
|
||||
|
||||
for mel_spec in self.mel_specs:
|
||||
Y_true = mel_spec(y_true)
|
||||
Y_pred = mel_spec(y_pred)
|
||||
loss = loss + log_magnitude_loss(Y_true, Y_pred)
|
||||
|
||||
loss = loss / len(self.mel_specs)
|
||||
|
||||
return loss
|
||||
|
||||
def create_weight_matrix(num_bins, bins_per_band=10):
|
||||
m = torch.zeros((num_bins, num_bins), dtype=torch.float32)
|
||||
|
||||
r0 = bins_per_band // 2
|
||||
r1 = bins_per_band - r0
|
||||
|
||||
for i in range(num_bins):
|
||||
i0 = max(i - r0, 0)
|
||||
j0 = min(i + r1, num_bins)
|
||||
|
||||
m[i, i0: j0] += 1
|
||||
|
||||
if i < r0:
|
||||
m[i, :r0 - i] += 1
|
||||
|
||||
if i > num_bins - r1:
|
||||
m[i, num_bins - r1 - i:] += 1
|
||||
|
||||
return m / bins_per_band
|
||||
|
||||
def weighted_spectral_convergence(Y_true, Y_pred, w):
|
||||
|
||||
# calculate sfm based weights
|
||||
logY = torch.log(torch.abs(Y_true) + 1e-9)
|
||||
Y = torch.abs(Y_true)
|
||||
|
||||
avg_logY = torch.matmul(logY.transpose(1, 2), w)
|
||||
avg_Y = torch.matmul(Y.transpose(1, 2), w)
|
||||
|
||||
sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
|
||||
|
||||
weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
|
||||
|
||||
loss = torch.mean(
|
||||
torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
|
||||
/ (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def gen_filterbank(N, Fs=16000):
|
||||
in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
|
||||
out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]
|
||||
#ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
|
||||
ERB_N = 24.7 + .108*in_freq
|
||||
delta = np.abs(in_freq-out_freq)/ERB_N
|
||||
center = (delta<.5).astype('float32')
|
||||
R = -12*center*delta**2 + (1-center)*(3-12*delta)
|
||||
RE = 10.**(R/10.)
|
||||
norm = np.sum(RE, axis=1)
|
||||
RE = RE/norm[:, np.newaxis]
|
||||
return torch.from_numpy(RE)
|
||||
|
||||
def smooth_log_mag(Y_true, Y_pred, filterbank):
|
||||
Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true))
|
||||
Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred))
|
||||
|
||||
loss = torch.abs(
|
||||
torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9)
|
||||
)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
return loss
|
||||
|
||||
class MRSTFTLoss(nn.Module):
|
||||
def __init__(self,
|
||||
fft_sizes=[2048, 1024, 512, 256, 128, 64],
|
||||
overlap=0.5,
|
||||
window='hann_window',
|
||||
fs=16000,
|
||||
log_mag_weight=1,
|
||||
sc_weight=0,
|
||||
wsc_weight=0,
|
||||
smooth_log_mag_weight=0,
|
||||
sxcorr_weight=0):
|
||||
super().__init__()
|
||||
|
||||
self.fft_sizes = fft_sizes
|
||||
self.overlap = overlap
|
||||
self.window = window
|
||||
self.log_mag_weight = log_mag_weight
|
||||
self.sc_weight = sc_weight
|
||||
self.wsc_weight = wsc_weight
|
||||
self.smooth_log_mag_weight = smooth_log_mag_weight
|
||||
self.sxcorr_weight = sxcorr_weight
|
||||
self.fs = fs
|
||||
|
||||
# weights for SFM weighted spectral convergence loss
|
||||
self.wsc_weights = torch.nn.ParameterDict()
|
||||
for fft_size in fft_sizes:
|
||||
width = min(11, int(1000 * fft_size / self.fs + .5))
|
||||
width += width % 2
|
||||
self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
|
||||
create_weight_matrix(fft_size // 2 + 1, width),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
# filterbanks for smooth log magnitude loss
|
||||
self.filterbanks = torch.nn.ParameterDict()
|
||||
for fft_size in fft_sizes:
|
||||
self.filterbanks[str(fft_size)] = torch.nn.Parameter(
|
||||
gen_filterbank(fft_size//2),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
def __call__(self, y_true, y_pred):
|
||||
|
||||
|
||||
lm_loss = torch.zeros(1, device=y_true.device)
|
||||
sc_loss = torch.zeros(1, device=y_true.device)
|
||||
wsc_loss = torch.zeros(1, device=y_true.device)
|
||||
slm_loss = torch.zeros(1, device=y_true.device)
|
||||
sxcorr_loss = torch.zeros(1, device=y_true.device)
|
||||
|
||||
for fft_size in self.fft_sizes:
|
||||
hop_size = int(round(fft_size * (1 - self.overlap)))
|
||||
win_size = fft_size
|
||||
|
||||
Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
|
||||
Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
|
||||
|
||||
if self.log_mag_weight > 0:
|
||||
lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
|
||||
|
||||
if self.sc_weight > 0:
|
||||
sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
|
||||
|
||||
if self.wsc_weight > 0:
|
||||
wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
|
||||
|
||||
if self.smooth_log_mag_weight > 0:
|
||||
slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
|
||||
|
||||
if self.sxcorr_weight > 0:
|
||||
sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
|
||||
|
||||
|
||||
total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
|
||||
+ self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
|
||||
+ self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
|
||||
|
||||
return total_loss
|
56
dnn/torch/osce/make_default_setup.py
Normal file
56
dnn/torch/osce/make_default_setup.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
from utils.templates import setup_dict
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('name', type=str, help='name of default setup file')
|
||||
parser.add_argument('--model', choices=['lace'], help='model name', default='lace')
|
||||
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setup = setup_dict[args.model]
|
||||
|
||||
# update dataset if given
|
||||
if type(args.path2dataset) != type(None):
|
||||
setup['dataset'] = args.path2dataset
|
||||
|
||||
name = args.name
|
||||
if not name.endswith('.yml'):
|
||||
name += '.yml'
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open(name, 'w') as f:
|
||||
f.write(yaml.dump(setup))
|
36
dnn/torch/osce/models/__init__.py
Normal file
36
dnn/torch/osce/models/__init__.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from .lace import LACE
|
||||
|
||||
|
||||
|
||||
model_dict = {
|
||||
'lace': LACE
|
||||
}
|
176
dnn/torch/osce/models/lace.py
Normal file
176
dnn/torch/osce/models/lace.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.silk_feature_net_pl import SilkFeatureNetPL
|
||||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
class LACE(NNSBase):
|
||||
""" Linear-Adaptive Coding Enhancer """
|
||||
FRAME_SIZE=80
|
||||
|
||||
def __init__(self,
|
||||
num_features=47,
|
||||
pitch_embedding_dim=64,
|
||||
cond_dim=256,
|
||||
pitch_max=257,
|
||||
kernel_size=15,
|
||||
preemph=0.85,
|
||||
skip=91,
|
||||
comb_gain_limit_db=-6,
|
||||
global_gain_limits_db=[-6, 6],
|
||||
conv_gain_limits_db=[-6, 6],
|
||||
numbits_range=[50, 650],
|
||||
numbits_embedding_dim=8,
|
||||
hidden_feature_dim=64,
|
||||
partial_lookahead=True,
|
||||
norm_p=2):
|
||||
|
||||
super().__init__(skip=skip, preemph=preemph)
|
||||
|
||||
|
||||
self.num_features = num_features
|
||||
self.cond_dim = cond_dim
|
||||
self.pitch_max = pitch_max
|
||||
self.pitch_embedding_dim = pitch_embedding_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.preemph = preemph
|
||||
self.skip = skip
|
||||
self.numbits_range = numbits_range
|
||||
self.numbits_embedding_dim = numbits_embedding_dim
|
||||
self.hidden_feature_dim = hidden_feature_dim
|
||||
self.partial_lookahead = partial_lookahead
|
||||
|
||||
# pitch embedding
|
||||
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
|
||||
|
||||
# numbits embedding
|
||||
self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
|
||||
|
||||
# feature net
|
||||
if partial_lookahead:
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
|
||||
else:
|
||||
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
|
||||
|
||||
# comb filters
|
||||
left_pad = self.kernel_size // 2
|
||||
right_pad = self.kernel_size - 1 - left_pad
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
# spectral shaping
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
frame_rate = rate / self.FRAME_SIZE
|
||||
|
||||
# feature net
|
||||
feature_net_flops = self.feature_net.flop_count(frame_rate)
|
||||
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
|
||||
af_flops = self.af1.flop_count(rate)
|
||||
|
||||
if verbose:
|
||||
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
|
||||
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
|
||||
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
|
||||
|
||||
return feature_net_flops + comb_flops + af_flops
|
||||
|
||||
def forward(self, x, features, periods, numbits, debug=False):
|
||||
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
|
||||
|
||||
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
|
||||
y = self.cf1(x, cf, periods, debug=debug)
|
||||
|
||||
y = self.cf2(y, cf, periods, debug=debug)
|
||||
|
||||
y = self.af1(y, cf, debug=debug)
|
||||
|
||||
return y
|
||||
|
||||
def get_impulse_responses(self, features, periods, numbits):
|
||||
""" generates impoulse responses on frame centers (input without batch dimension) """
|
||||
|
||||
num_frames = features.size(0)
|
||||
batch_size = 32
|
||||
max_len = 2 * (self.pitch_max + self.kernel_size) + 10
|
||||
|
||||
# spread out some pulses
|
||||
x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
|
||||
for b in range(batch_size):
|
||||
x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1
|
||||
|
||||
# prepare input
|
||||
x = torch.from_numpy(x).float().to(features.device)
|
||||
features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
|
||||
periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
|
||||
numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)
|
||||
|
||||
# run network
|
||||
with torch.no_grad():
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
|
||||
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
y = self.cf1(x, cf, periods, debug=False)
|
||||
y = self.cf2(y, cf, periods, debug=False)
|
||||
y = self.af1(y, cf, debug=False)
|
||||
|
||||
# collect responses
|
||||
y = y.detach().squeeze().cpu().numpy()
|
||||
cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
|
||||
num_responses = num_frames - cut_frames
|
||||
responses = np.zeros((num_responses, max_len))
|
||||
|
||||
for i in range(num_responses):
|
||||
b = i % batch_size
|
||||
start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
|
||||
stop = start + max_len
|
||||
|
||||
responses[i, :] = y[b, start:stop]
|
||||
|
||||
return responses
|
69
dnn/torch/osce/models/nns_base.py
Normal file
69
dnn/torch/osce/models/nns_base.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class NNSBase(nn.Module):
|
||||
|
||||
def __init__(self, skip=91, preemph=0.85):
|
||||
super().__init__()
|
||||
|
||||
self.skip = skip
|
||||
self.preemph = preemph
|
||||
|
||||
def process(self, sig, features, periods, numbits, debug=False):
|
||||
|
||||
self.eval()
|
||||
has_numbits = 'numbits' in self.forward.__code__.co_varnames
|
||||
device = next(iter(self.parameters())).device
|
||||
with torch.no_grad():
|
||||
|
||||
# run model
|
||||
x = sig.view(1, 1, -1).to(device)
|
||||
f = features.unsqueeze(0).to(device)
|
||||
p = periods.unsqueeze(0).to(device)
|
||||
n = numbits.unsqueeze(0).to(device)
|
||||
|
||||
if has_numbits:
|
||||
y = self.forward(x, f, p, n, debug=debug).squeeze()
|
||||
else:
|
||||
y = self.forward(x, f, p, debug=debug).squeeze()
|
||||
|
||||
# deemphasis
|
||||
if self.preemph > 0:
|
||||
for i in range(len(y) - 1):
|
||||
y[i + 1] += self.preemph * y[i]
|
||||
|
||||
# delay compensation
|
||||
y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device)))
|
||||
out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
|
||||
|
||||
return out
|
68
dnn/torch/osce/models/scale_embedding.py
Normal file
68
dnn/torch/osce/models/scale_embedding.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ScaleEmbedding(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
min_val,
|
||||
max_val,
|
||||
logscale=False):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if min_val >= max_val:
|
||||
raise ValueError('min_val must be smaller than max_val')
|
||||
|
||||
if min_val <= 0 and logscale:
|
||||
raise ValueError('min_val must be positive when logscale is true')
|
||||
|
||||
self.dim = dim
|
||||
self.logscale = logscale
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
|
||||
if logscale:
|
||||
self.min_val = m.log(self.min_val)
|
||||
self.max_val = m.log(self.max_val)
|
||||
|
||||
|
||||
self.offset = (self.min_val + self.max_val) / 2
|
||||
self.scale_factors = nn.Parameter(
|
||||
torch.arange(1, dim+1, dtype=torch.float32) * torch.pi / (self.max_val - self.min_val)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.logscale: x = torch.log(x)
|
||||
x = torch.clip(x, self.min_val, self.max_val) - self.offset
|
||||
return torch.sin(x.unsqueeze(-1) * self.scale_factors - 0.5)
|
86
dnn/torch/osce/models/silk_feature_net.py
Normal file
86
dnn/torch/osce/models/silk_feature_net.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class SilkFeatureNet(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
feature_dim=47,
|
||||
num_channels=256,
|
||||
lookahead=False):
|
||||
|
||||
super(SilkFeatureNet, self).__init__()
|
||||
|
||||
self.feature_dim = feature_dim
|
||||
self.num_channels = num_channels
|
||||
self.lookahead = lookahead
|
||||
|
||||
self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
|
||||
self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)
|
||||
|
||||
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
|
||||
|
||||
def flop_count(self, rate=200):
|
||||
count = 0
|
||||
for conv in self.conv1, self.conv2:
|
||||
count += _conv1d_flop_count(conv, rate)
|
||||
|
||||
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def forward(self, features, state=None):
|
||||
""" features shape: (batch_size, num_frames, feature_dim) """
|
||||
|
||||
batch_size = features.size(0)
|
||||
|
||||
if state is None:
|
||||
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
|
||||
|
||||
|
||||
features = features.permute(0, 2, 1)
|
||||
if self.lookahead:
|
||||
c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
|
||||
c = torch.tanh(self.conv2(F.pad(c, [2, 2])))
|
||||
else:
|
||||
c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
|
||||
c = torch.tanh(self.conv2(F.pad(c, [4, 0])))
|
||||
|
||||
c = c.permute(0, 2, 1)
|
||||
|
||||
c, _ = self.gru(c, state)
|
||||
|
||||
return c
|
90
dnn/torch/osce/models/silk_feature_net_pl.py
Normal file
90
dnn/torch/osce/models/silk_feature_net_pl.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class SilkFeatureNetPL(nn.Module):
|
||||
""" feature net with partial lookahead """
|
||||
def __init__(self,
|
||||
feature_dim=47,
|
||||
num_channels=256,
|
||||
hidden_feature_dim=64):
|
||||
|
||||
super(SilkFeatureNetPL, self).__init__()
|
||||
|
||||
self.feature_dim = feature_dim
|
||||
self.num_channels = num_channels
|
||||
self.hidden_feature_dim = hidden_feature_dim
|
||||
|
||||
self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)
|
||||
self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)
|
||||
self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
|
||||
|
||||
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
|
||||
|
||||
def flop_count(self, rate=200):
|
||||
count = 0
|
||||
for conv in self.conv1, self.conv2, self.tconv:
|
||||
count += _conv1d_flop_count(conv, rate)
|
||||
|
||||
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def forward(self, features, state=None):
|
||||
""" features shape: (batch_size, num_frames, feature_dim) """
|
||||
|
||||
batch_size = features.size(0)
|
||||
num_frames = features.size(1)
|
||||
|
||||
if state is None:
|
||||
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
|
||||
|
||||
features = features.permute(0, 2, 1)
|
||||
# dimensionality reduction
|
||||
c = torch.tanh(self.conv1(features))
|
||||
|
||||
# frame accumulation
|
||||
c = c.permute(0, 2, 1)
|
||||
c = c.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1)
|
||||
c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
|
||||
|
||||
# upsampling
|
||||
c = self.tconv(c)
|
||||
c = c.permute(0, 2, 1)
|
||||
|
||||
c, _ = self.gru(c, state)
|
||||
|
||||
return c
|
96
dnn/torch/osce/test_model.py
Normal file
96
dnn/torch/osce/test_model.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
from models import model_dict
|
||||
from utils.silk_features import load_inference_data
|
||||
from utils import endoscopy
|
||||
|
||||
debug = False
|
||||
if debug:
|
||||
args = type('dummy', (object,),
|
||||
{
|
||||
'input' : 'testitems/all_0_orig.se',
|
||||
'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
|
||||
'output' : 'out.wav',
|
||||
})()
|
||||
else:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('input', type=str, help='path to folder with features and signals')
|
||||
parser.add_argument('checkpoint', type=str, help='checkpoint file')
|
||||
parser.add_argument('output', type=str, help='output file')
|
||||
parser.add_argument('--debug', action='store_true', help='enables debug output')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(2)
|
||||
|
||||
input_folder = args.input
|
||||
checkpoint_file = args.checkpoint
|
||||
|
||||
|
||||
output_file = args.output
|
||||
if not output_file.endswith('.wav'):
|
||||
output_file += '.wav'
|
||||
|
||||
checkpoint = torch.load(checkpoint_file, map_location="cpu")
|
||||
|
||||
# check model
|
||||
if not 'name' in checkpoint['setup']['model']:
|
||||
print(f'warning: did not find model name entry in setup, using pitchpostfilter per default')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = checkpoint['setup']['model']['name']
|
||||
|
||||
model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
|
||||
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
# generate model input
|
||||
setup = checkpoint['setup']
|
||||
signal, features, periods, numbits = load_inference_data(input_folder, **setup['data'])
|
||||
|
||||
if args.debug:
|
||||
endoscopy.init()
|
||||
|
||||
output = model.process(signal, features, periods, numbits, debug=args.debug)
|
||||
|
||||
wavfile.write(output_file, 16000, output.cpu().numpy())
|
||||
|
||||
if args.debug:
|
||||
endoscopy.close()
|
297
dnn/torch/osce/train_model.py
Normal file
297
dnn/torch/osce/train_model.py
Normal file
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scipy.io import wavfile
|
||||
|
||||
import pesq
|
||||
|
||||
from data import SilkEnhancementSet
|
||||
from models import model_dict
|
||||
from engine.engine import train_one_epoch, evaluate
|
||||
|
||||
|
||||
from utils.silk_features import load_inference_data
|
||||
from utils.misc import count_parameters
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'name' in setup['model']:
|
||||
print(f'warning: did not find model entry in setup, using default PitchPostFilter')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = setup['model']['name']
|
||||
|
||||
# prepare output folder
|
||||
if os.path.exists(args.output):
|
||||
print("warning: output folder exists")
|
||||
|
||||
reply = input('continue? (y/n): ')
|
||||
while reply not in {'y', 'n'}:
|
||||
reply = input('continue? (y/n): ')
|
||||
|
||||
if reply == 'n':
|
||||
os._exit()
|
||||
else:
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
ref = None
|
||||
if args.testdata is not None:
|
||||
|
||||
testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
|
||||
|
||||
inference_test = True
|
||||
inference_folder = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)
|
||||
|
||||
try:
|
||||
ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
inference_test = False
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
|
||||
# load training dataset
|
||||
data_config = setup['data']
|
||||
data = SilkEnhancementSet(setup['dataset'], **data_config)
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
|
||||
|
||||
if args.initial_checkpoint is not None:
|
||||
print(f"loading state dict from {args.initial_checkpoint}...")
|
||||
chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
|
||||
model.load_state_dict(chkpt['state_dict'])
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
# push model to device
|
||||
model.to(device)
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr)
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# loss
|
||||
w_l1 = setup['training']['loss']['w_l1']
|
||||
w_lm = setup['training']['loss']['w_lm']
|
||||
w_slm = setup['training']['loss']['w_slm']
|
||||
w_sc = setup['training']['loss']['w_sc']
|
||||
w_logmel = setup['training']['loss']['w_logmel']
|
||||
w_wsc = setup['training']['loss']['w_wsc']
|
||||
w_xcorr = setup['training']['loss']['w_xcorr']
|
||||
w_sxcorr = setup['training']['loss']['w_sxcorr']
|
||||
w_l2 = setup['training']['loss']['w_l2']
|
||||
|
||||
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
|
||||
|
||||
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
|
||||
logmelloss = MRLogMelLoss().to(device)
|
||||
|
||||
def xcorr_loss(y_true, y_pred):
|
||||
dims = list(range(1, len(y_true.shape)))
|
||||
|
||||
loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
|
||||
|
||||
return torch.mean(loss)
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
|
||||
dims = list(range(1, len(y_true.shape)))
|
||||
|
||||
loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def td_l1(y_true, y_pred, pow=0):
|
||||
dims = list(range(1, len(y_true.shape)))
|
||||
tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)
|
||||
|
||||
return torch.mean(tmp)
|
||||
|
||||
def criterion(x, y):
|
||||
|
||||
return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
|
||||
+ w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
|
||||
|
||||
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
print("summary:")
|
||||
|
||||
print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
|
||||
if hasattr(model, 'flop_count'):
|
||||
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
|
||||
|
||||
if ref is not None:
|
||||
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
|
||||
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
|
||||
print(f"initial MOS (PESQ): {initial_mos}")
|
||||
|
||||
best_loss = 1e9
|
||||
|
||||
for ep in range(1, epochs + 1):
|
||||
print(f"training epoch {ep}...")
|
||||
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
|
||||
|
||||
|
||||
# save checkpoint
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['loss'] = new_loss
|
||||
|
||||
if run_validation:
|
||||
print("running validation...")
|
||||
validation_loss = evaluate(model, criterion, validation_dataloader, device)
|
||||
checkpoint['validation_loss'] = validation_loss
|
||||
|
||||
if validation_loss < best_loss:
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
|
||||
best_loss = validation_loss
|
||||
|
||||
if inference_test:
|
||||
print("running inference test...")
|
||||
out = model.process(testsignal, features, periods, numbits).cpu().numpy()
|
||||
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
|
||||
if ref is not None:
|
||||
mos = pesq.pesq(16000, ref, out, mode='wb')
|
||||
print(f"MOS (PESQ): {mos}")
|
||||
|
||||
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
|
||||
print()
|
||||
|
||||
print('Done')
|
35
dnn/torch/osce/utils/complexity.py
Normal file
35
dnn/torch/osce/utils/complexity.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
def _conv1d_flop_count(layer, rate):
|
||||
return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0]
|
||||
|
||||
|
||||
def _dense_flop_count(layer, rate):
|
||||
return 2 * ((layer.in_features + 1) * layer.out_features * rate )
|
234
dnn/torch/osce/utils/endoscopy.py
Normal file
234
dnn/torch/osce/utils/endoscopy.py
Normal file
|
@ -0,0 +1,234 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
|
||||
hidden_size = gru.hidden_size
|
||||
|
||||
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
|
||||
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
|
||||
|
||||
# reset gate
|
||||
start, stop = 0 * hidden_size, 1 * hidden_size
|
||||
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# update gate
|
||||
start, stop = 1 * hidden_size, 2 * hidden_size
|
||||
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# new gate
|
||||
start, stop = 2 * hidden_size, 3 * hidden_size
|
||||
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
|
||||
|
||||
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
|
||||
""" sets up output folder for endoscopy data """
|
||||
|
||||
global _folder
|
||||
_folder = folder
|
||||
|
||||
if not os.path.exists(folder):
|
||||
os.makedirs(folder)
|
||||
else:
|
||||
print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
|
||||
|
||||
def write_data(key, data, fs):
|
||||
""" appends data to previous data written under key """
|
||||
|
||||
global _state
|
||||
|
||||
# convert to numpy if torch.Tensor is given
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.detach().numpy()
|
||||
|
||||
if not key in _state:
|
||||
_state[key] = {
|
||||
'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
|
||||
'fs' : fs,
|
||||
'dim' : tuple(data.shape),
|
||||
'dtype' : str(data.dtype)
|
||||
}
|
||||
|
||||
with open(os.path.join(_folder, key + '.yml'), 'w') as f:
|
||||
f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
|
||||
else:
|
||||
if _state[key]['fs'] != fs:
|
||||
raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
|
||||
if _state[key]['dtype'] != str(data.dtype):
|
||||
raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
|
||||
if _state[key]['dim'] != tuple(data.shape):
|
||||
raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
|
||||
|
||||
_state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
|
||||
""" clean up """
|
||||
for key in _state.keys():
|
||||
_state[key]['fid'].close()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
|
||||
""" retrieves written data as numpy arrays """
|
||||
|
||||
|
||||
keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
|
||||
|
||||
return_dict = dict()
|
||||
|
||||
for key in keys:
|
||||
with open(os.path.join(folder, key + '.yml'), 'r') as f:
|
||||
value = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
with open(os.path.join(folder, key + '.bin'), 'rb') as f:
|
||||
data = np.frombuffer(f.read(), dtype=value['dtype'])
|
||||
|
||||
value['data'] = data.reshape((-1,) + value['dim'])
|
||||
|
||||
return_dict[key] = value
|
||||
|
||||
return return_dict
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
|
||||
""" calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
|
||||
|
||||
if len(shape) > 1:
|
||||
pixel_count = 1
|
||||
for s in shape:
|
||||
pixel_count *= s
|
||||
else:
|
||||
pixel_count = shape[0]
|
||||
|
||||
if pixel_count == 1:
|
||||
return (1,)
|
||||
|
||||
num_columns = int((pixel_count / target_ratio)**.5)
|
||||
|
||||
while (pixel_count % num_columns):
|
||||
num_columns -= 1
|
||||
|
||||
num_rows = pixel_count // num_columns
|
||||
|
||||
return (num_rows, num_columns)
|
||||
|
||||
def get_type_and_shape(shape):
|
||||
|
||||
# can happen if data is one dimensional
|
||||
if len(shape) == 0:
|
||||
shape = (1,)
|
||||
|
||||
# calculate pixel count
|
||||
if len(shape) > 1:
|
||||
pixel_count = 1
|
||||
for s in shape:
|
||||
pixel_count *= s
|
||||
else:
|
||||
pixel_count = shape[0]
|
||||
|
||||
if pixel_count == 1:
|
||||
return 'plot', (1, )
|
||||
|
||||
# stay with shape if already 2-dimensional
|
||||
if len(shape) == 2:
|
||||
if (shape[0] != pixel_count) or (shape[1] != pixel_count):
|
||||
return 'image', shape
|
||||
|
||||
return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
|
||||
|
||||
# determine plot setup
|
||||
num_keys = len(data.keys())
|
||||
|
||||
num_rows = int((num_keys * 3/4) ** .5)
|
||||
|
||||
num_cols = (num_keys + num_rows - 1) // num_rows
|
||||
|
||||
fig, axs = plt.subplots(num_rows, num_cols)
|
||||
fig.set_size_inches(num_cols * 5, num_rows * 5)
|
||||
|
||||
display = dict()
|
||||
|
||||
fs_max = max([val['fs'] for val in data.values()])
|
||||
|
||||
num_samples = max([val['data'].shape[0] for val in data.values()])
|
||||
|
||||
keys = sorted(data.keys())
|
||||
|
||||
# inspect data
|
||||
for i, key in enumerate(keys):
|
||||
axs[i // num_cols, i % num_cols].title.set_text(key)
|
||||
|
||||
display[key] = dict()
|
||||
|
||||
display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
|
||||
display[key]['down_factor'] = data[key]['fs'] / fs_max
|
||||
|
||||
start_index = max(start_index, half_signal_window_length)
|
||||
while stop_index < 0:
|
||||
stop_index += num_samples
|
||||
|
||||
stop_index = min(stop_index, num_samples - half_signal_window_length)
|
||||
|
||||
# actual plotting
|
||||
frames = []
|
||||
for index in range(start_index, stop_index):
|
||||
ims = []
|
||||
for i, key in enumerate(keys):
|
||||
feature_index = int(round(index * display[key]['down_factor']))
|
||||
|
||||
if display[key]['type'] == 'plot':
|
||||
ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
|
||||
|
||||
elif display[key]['type'] == 'image':
|
||||
ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
|
||||
|
||||
frames.append(ims)
|
||||
|
||||
ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
|
||||
|
||||
if not filename.endswith('.mp4'):
|
||||
filename += '.mp4'
|
||||
|
||||
ani.save(filename)
|
236
dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
Normal file
236
dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
class LimitedAdaptiveComb1d(nn.Module):
|
||||
COUNTER = 1
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
feature_dim,
|
||||
frame_size=160,
|
||||
overlap_size=40,
|
||||
use_bias=True,
|
||||
padding=None,
|
||||
max_lag=256,
|
||||
name=None,
|
||||
gain_limit_db=10,
|
||||
global_gain_limits_db=[-6, 6],
|
||||
norm_p=2):
|
||||
"""
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
feature_dim : int
|
||||
dimension of features from which kernels, biases and gains are computed
|
||||
|
||||
frame_size : int, optional
|
||||
frame size, defaults to 160
|
||||
|
||||
overlap_size : int, optional
|
||||
overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
|
||||
|
||||
use_bias : bool, optional
|
||||
if true, biases will be added to output channels. Defaults to True
|
||||
|
||||
padding : List[int, int], optional
|
||||
left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2]
|
||||
|
||||
max_lag : int, optional
|
||||
maximal pitch lag, defaults to 256
|
||||
|
||||
have_a0 : bool, optional
|
||||
If true, the filter coefficient a0 will be learned as a positive gain (requires in_channels == out_channels). Otherwise, a0 is set to 0. Defaults to False
|
||||
|
||||
name: str or None, optional
|
||||
specifies a name attribute for the module. If None the name is auto generated as comb_1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
|
||||
|
||||
"""
|
||||
|
||||
super(LimitedAdaptiveComb1d, self).__init__()
|
||||
|
||||
self.in_channels = 1
|
||||
self.out_channels = 1
|
||||
self.feature_dim = feature_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.frame_size = frame_size
|
||||
self.overlap_size = overlap_size
|
||||
self.use_bias = use_bias
|
||||
self.max_lag = max_lag
|
||||
self.limit_db = gain_limit_db
|
||||
self.norm_p = norm_p
|
||||
|
||||
if name is None:
|
||||
self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
|
||||
LimitedAdaptiveComb1d.COUNTER += 1
|
||||
else:
|
||||
self.name = name
|
||||
|
||||
# network for generating convolution weights
|
||||
self.conv_kernel = nn.Linear(feature_dim, kernel_size)
|
||||
|
||||
if self.use_bias:
|
||||
self.conv_bias = nn.Linear(feature_dim,1)
|
||||
|
||||
# comb filter gain
|
||||
self.filter_gain = nn.Linear(feature_dim, 1)
|
||||
self.log_gain_limit = gain_limit_db * 0.11512925464970229
|
||||
with torch.no_grad():
|
||||
self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
|
||||
|
||||
self.global_filter_gain = nn.Linear(feature_dim, 1)
|
||||
log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
|
||||
self.filter_gain_a = (log_max - log_min) / 2
|
||||
self.filter_gain_b = (log_max + log_min) / 2
|
||||
|
||||
if type(padding) == type(None):
|
||||
self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
|
||||
|
||||
def forward(self, x, features, lags, debug=False):
|
||||
""" adaptive 1d convolution
|
||||
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
x : torch.tensor
|
||||
input signal of shape (batch_size, in_channels, num_samples)
|
||||
|
||||
feathres : torch.tensor
|
||||
frame-wise features of shape (batch_size, num_frames, feature_dim)
|
||||
|
||||
lags: torch.LongTensor
|
||||
frame-wise lags for comb-filtering
|
||||
|
||||
"""
|
||||
|
||||
batch_size = x.size(0)
|
||||
num_frames = features.size(1)
|
||||
num_samples = x.size(2)
|
||||
frame_size = self.frame_size
|
||||
overlap_size = self.overlap_size
|
||||
kernel_size = self.kernel_size
|
||||
win1 = torch.flip(self.overlap_win, [0])
|
||||
win2 = self.overlap_win
|
||||
|
||||
if num_samples // self.frame_size != num_frames:
|
||||
raise ValueError('non matching sizes in AdaptiveConv1d.forward')
|
||||
|
||||
conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
|
||||
conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
|
||||
|
||||
if self.use_bias:
|
||||
conv_biases = self.conv_bias(features).permute(0, 2, 1)
|
||||
|
||||
conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
|
||||
# calculate gains
|
||||
global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
|
||||
|
||||
if debug and batch_size == 1:
|
||||
key = self.name + "_gains"
|
||||
write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
|
||||
key = self.name + "_kernels"
|
||||
write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
|
||||
key = self.name + "_lags"
|
||||
write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
|
||||
key = self.name + "_global_conv_gains"
|
||||
write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
|
||||
|
||||
|
||||
# frame-wise convolution with overlap-add
|
||||
output_frames = []
|
||||
overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
|
||||
x = F.pad(x, self.padding)
|
||||
x = F.pad(x, [self.max_lag, self.overlap_size])
|
||||
|
||||
idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
|
||||
idx = torch.repeat_interleave(idx, batch_size, 0)
|
||||
idx = torch.repeat_interleave(idx, self.in_channels, 1)
|
||||
|
||||
|
||||
for i in range(num_frames):
|
||||
|
||||
cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
|
||||
xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
|
||||
|
||||
new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
|
||||
|
||||
|
||||
if self.use_bias:
|
||||
new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
|
||||
|
||||
offset = self.max_lag + self.padding[0]
|
||||
new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])
|
||||
|
||||
# overlapping part
|
||||
output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
|
||||
|
||||
# non-overlapping part
|
||||
output_frames.append(new_chunk[:, :, overlap_size : frame_size])
|
||||
|
||||
# mem for next frame
|
||||
overlap_mem = new_chunk[:, :, frame_size :]
|
||||
|
||||
# concatenate chunks
|
||||
output = torch.cat(output_frames, dim=-1)
|
||||
|
||||
return output
|
||||
|
||||
def flop_count(self, rate):
|
||||
frame_rate = rate / self.frame_size
|
||||
overlap = self.overlap_size
|
||||
overhead = overlap / self.frame_size
|
||||
|
||||
count = 0
|
||||
|
||||
# kernel computation and filtering
|
||||
count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
|
||||
count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
|
||||
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
|
||||
|
||||
# bias computation
|
||||
if self.use_bias:
|
||||
count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
|
||||
|
||||
# a0 computation
|
||||
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
|
||||
|
||||
# windowing
|
||||
count += overlap * frame_rate * 3 * self.out_channels
|
||||
|
||||
return count
|
222
dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
Normal file
222
dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
Normal file
|
@ -0,0 +1,222 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
class LimitedAdaptiveConv1d(nn.Module):
|
||||
COUNTER = 1
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
feature_dim,
|
||||
frame_size=160,
|
||||
overlap_size=40,
|
||||
use_bias=True,
|
||||
padding=None,
|
||||
name=None,
|
||||
gain_limits_db=[-6, 6],
|
||||
shape_gain_db=0,
|
||||
norm_p=2):
|
||||
"""
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
in_channels : int
|
||||
number of input channels
|
||||
|
||||
out_channels : int
|
||||
number of output channels
|
||||
|
||||
feature_dim : int
|
||||
dimension of features from which kernels, biases and gains are computed
|
||||
|
||||
frame_size : int
|
||||
frame size
|
||||
|
||||
overlap_size : int
|
||||
overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame
|
||||
|
||||
use_bias : bool
|
||||
if true, biases will be added to output channels
|
||||
|
||||
|
||||
padding : List[int, int]
|
||||
|
||||
"""
|
||||
|
||||
super(LimitedAdaptiveConv1d, self).__init__()
|
||||
|
||||
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.feature_dim = feature_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.frame_size = frame_size
|
||||
self.overlap_size = overlap_size
|
||||
self.use_bias = use_bias
|
||||
self.gain_limits_db = gain_limits_db
|
||||
self.shape_gain_db = shape_gain_db
|
||||
self.norm_p = norm_p
|
||||
|
||||
if name is None:
|
||||
self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
|
||||
LimitedAdaptiveConv1d.COUNTER += 1
|
||||
else:
|
||||
self.name = name
|
||||
|
||||
# network for generating convolution weights
|
||||
self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
|
||||
|
||||
if self.use_bias:
|
||||
self.conv_bias = nn.Linear(feature_dim, out_channels)
|
||||
|
||||
self.shape_gain = min(1, 10**(shape_gain_db / 20))
|
||||
|
||||
self.filter_gain = nn.Linear(feature_dim, out_channels)
|
||||
log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
|
||||
self.filter_gain_a = (log_max - log_min) / 2
|
||||
self.filter_gain_b = (log_max + log_min) / 2
|
||||
|
||||
if type(padding) == type(None):
|
||||
self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
|
||||
|
||||
|
||||
def flop_count(self, rate):
|
||||
frame_rate = rate / self.frame_size
|
||||
overlap = self.overlap_size
|
||||
overhead = overlap / self.frame_size
|
||||
|
||||
count = 0
|
||||
|
||||
# kernel computation and filtering
|
||||
count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
|
||||
count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
|
||||
|
||||
# bias computation
|
||||
if self.use_bias:
|
||||
count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
|
||||
|
||||
# gain computation
|
||||
|
||||
count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
|
||||
|
||||
# windowing
|
||||
count += 3 * overlap * frame_rate * self.out_channels
|
||||
|
||||
return count
|
||||
|
||||
def forward(self, x, features, debug=False):
|
||||
""" adaptive 1d convolution
|
||||
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
x : torch.tensor
|
||||
input signal of shape (batch_size, in_channels, num_samples)
|
||||
|
||||
feathres : torch.tensor
|
||||
frame-wise features of shape (batch_size, num_frames, feature_dim)
|
||||
|
||||
"""
|
||||
|
||||
batch_size = x.size(0)
|
||||
num_frames = features.size(1)
|
||||
num_samples = x.size(2)
|
||||
frame_size = self.frame_size
|
||||
overlap_size = self.overlap_size
|
||||
kernel_size = self.kernel_size
|
||||
win1 = torch.flip(self.overlap_win, [0])
|
||||
win2 = self.overlap_win
|
||||
|
||||
if num_samples // self.frame_size != num_frames:
|
||||
raise ValueError('non matching sizes in AdaptiveConv1d.forward')
|
||||
|
||||
conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
|
||||
|
||||
# normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
|
||||
conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))
|
||||
|
||||
# limit shape
|
||||
id_kernels = torch.zeros_like(conv_kernels)
|
||||
id_kernels[..., self.padding[1]] = 1
|
||||
|
||||
conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels
|
||||
|
||||
if self.use_bias:
|
||||
conv_biases = self.conv_bias(features).permute(0, 2, 1)
|
||||
|
||||
# calculate gains
|
||||
conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
|
||||
if debug and batch_size == 1:
|
||||
key = self.name + "_gains"
|
||||
write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
|
||||
key = self.name + "_kernels"
|
||||
write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
|
||||
|
||||
|
||||
# frame-wise convolution with overlap-add
|
||||
output_frames = []
|
||||
overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
|
||||
x = F.pad(x, self.padding)
|
||||
x = F.pad(x, [0, self.overlap_size])
|
||||
|
||||
for i in range(num_frames):
|
||||
xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1))
|
||||
new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
|
||||
|
||||
if self.use_bias:
|
||||
new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
|
||||
|
||||
new_chunk = new_chunk * conv_gains[:, :, i : i + 1]
|
||||
|
||||
# overlapping part
|
||||
output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
|
||||
|
||||
# non-overlapping part
|
||||
output_frames.append(new_chunk[:, :, overlap_size : frame_size])
|
||||
|
||||
# mem for next frame
|
||||
overlap_mem = new_chunk[:, :, frame_size :]
|
||||
|
||||
# concatenate chunks
|
||||
output = torch.cat(output_frames, dim=-1)
|
||||
|
||||
return output
|
84
dnn/torch/osce/utils/layers/pitch_auto_correlator.py
Normal file
84
dnn/torch/osce/utils/layers/pitch_auto_correlator.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PitchAutoCorrelator(nn.Module):
|
||||
def __init__(self,
|
||||
frame_size=80,
|
||||
pitch_min=32,
|
||||
pitch_max=300,
|
||||
radius=2):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.frame_size = frame_size
|
||||
self.pitch_min = pitch_min
|
||||
self.pitch_max = pitch_max
|
||||
self.radius = radius
|
||||
|
||||
|
||||
def forward(self, x, periods):
|
||||
# x of shape (batch_size, channels, num_samples)
|
||||
# periods of shape (batch_size, num_frames)
|
||||
|
||||
num_frames = periods.size(1)
|
||||
batch_size = periods.size(0)
|
||||
num_samples = self.frame_size * num_frames
|
||||
channels = x.size(1)
|
||||
|
||||
assert num_samples == x.size(-1)
|
||||
|
||||
range = torch.arange(-self.radius, self.radius + 1, device=x.device)
|
||||
idx = torch.arange(self.frame_size * num_frames, device=x.device)
|
||||
p_up = torch.repeat_interleave(periods, self.frame_size, 1)
|
||||
lookup = idx + self.pitch_max - p_up
|
||||
lookup = lookup.unsqueeze(-1) + range
|
||||
lookup = lookup.unsqueeze(1)
|
||||
|
||||
# padding
|
||||
x_pad = F.pad(x, [self.pitch_max, 0])
|
||||
x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)
|
||||
|
||||
# framing
|
||||
x_select = torch.gather(x_ext, 2, lookup)
|
||||
x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
|
||||
lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)
|
||||
|
||||
# calculate auto-correlation
|
||||
dotp = torch.sum(x_frames * lag_frames, dim=-2)
|
||||
frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
|
||||
lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)
|
||||
|
||||
acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)
|
||||
|
||||
return acorr
|
42
dnn/torch/osce/utils/misc.py
Normal file
42
dnn/torch/osce/utils/misc.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
def count_parameters(model, verbose=False):
|
||||
total = 0
|
||||
for name, p in model.named_parameters():
|
||||
count = torch.ones_like(p).sum().item()
|
||||
|
||||
if verbose:
|
||||
print(f"{name}: {count} parameters")
|
||||
|
||||
total += count
|
||||
|
||||
return total
|
121
dnn/torch/osce/utils/pitch.py
Normal file
121
dnn/torch/osce/utils/pitch.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
def hangover(lags, num_frames=10):
|
||||
lags = lags.copy()
|
||||
count = 0
|
||||
last_lag = 0
|
||||
|
||||
for i in range(len(lags)):
|
||||
lag = lags[i]
|
||||
|
||||
if lag == 0:
|
||||
if count < num_frames:
|
||||
lags[i] = last_lag
|
||||
count += 1
|
||||
else:
|
||||
count = 0
|
||||
|
||||
return lags
|
||||
|
||||
|
||||
def smooth_pitch_lags(lags, d=2):
|
||||
|
||||
assert d < 4
|
||||
|
||||
num_silk_frames = len(lags) // 4
|
||||
|
||||
smoothed_lags = lags.copy()
|
||||
|
||||
tmp = np.arange(1, d+1)
|
||||
kernel = np.concatenate((tmp, [d+1], tmp[::-1]), dtype=np.float32)
|
||||
kernel = kernel / np.sum(kernel)
|
||||
|
||||
last = lags[0:d][::-1]
|
||||
for i in range(num_silk_frames):
|
||||
frame = lags[i * 4: (i+1) * 4]
|
||||
|
||||
if np.max(np.abs(frame)) == 0:
|
||||
last = frame[4-d:]
|
||||
continue
|
||||
|
||||
if i == num_silk_frames - 1:
|
||||
next = frame[4-d:][::-1]
|
||||
else:
|
||||
next = lags[(i+1) * 4 : (i+1) * 4 + d]
|
||||
|
||||
if np.max(np.abs(next)) == 0:
|
||||
next = frame[4-d:][::-1]
|
||||
|
||||
if np.max(np.abs(last)) == 0:
|
||||
last = frame[0:d][::-1]
|
||||
|
||||
smoothed_frame = np.convolve(np.concatenate((last, frame, next), dtype=np.float32), kernel, mode='valid')
|
||||
|
||||
smoothed_lags[i * 4: (i+1) * 4] = np.round(smoothed_frame)
|
||||
|
||||
last = frame[4-d:]
|
||||
|
||||
return smoothed_lags
|
||||
|
||||
def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
|
||||
eps = 1e-9
|
||||
|
||||
lag_multiplier = 2 if add_double_lag_acorr else 1
|
||||
|
||||
if history is None:
|
||||
history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)
|
||||
|
||||
offset = len(history)
|
||||
|
||||
assert offset >= max_lag + radius
|
||||
assert len(x) % frame_size == 0
|
||||
|
||||
num_frames = len(x) // frame_size
|
||||
lags = lags.copy()
|
||||
|
||||
x_ext = np.concatenate((history, x), dtype=x.dtype)
|
||||
|
||||
d = radius
|
||||
num_acorrs = 2 * d + 1
|
||||
acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)
|
||||
|
||||
for idx in range(num_frames):
|
||||
lag = lags[idx].item()
|
||||
frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]
|
||||
|
||||
for k in range(lag_multiplier):
|
||||
lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
|
||||
for j in range(num_acorrs):
|
||||
past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
|
||||
acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)
|
||||
|
||||
return acorrs, lags
|
151
dnn/torch/osce/utils/silk_features.py
Normal file
151
dnn/torch/osce/utils/silk_features.py
Normal file
|
@ -0,0 +1,151 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import scipy
|
||||
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
|
||||
|
||||
def spec_from_lpc(a, n_fft=128, eps=1e-9):
|
||||
order = a.shape[-1]
|
||||
assert order + 1 < n_fft
|
||||
|
||||
x = np.zeros((*a.shape[:-1], n_fft ))
|
||||
x[..., 0] = 1
|
||||
x[..., 1:1 + order] = -a
|
||||
|
||||
X = np.fft.fft(x, axis=-1)
|
||||
X = np.abs(X[..., :n_fft//2 + 1]) ** 2
|
||||
|
||||
S = 1 / (X + eps)
|
||||
|
||||
return S
|
||||
|
||||
def silk_feature_factory(no_pitch_value=256,
|
||||
acorr_radius=2,
|
||||
pitch_hangover=8,
|
||||
num_bands_clean_spec=64,
|
||||
num_bands_noisy_spec=18,
|
||||
noisy_spec_scale='opus',
|
||||
noisy_apply_dct=True,
|
||||
add_offset=False,
|
||||
add_double_lag_acorr=False
|
||||
):
|
||||
|
||||
w = scipy.signal.windows.cosine(320)
|
||||
fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
|
||||
fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)
|
||||
|
||||
def create_features(noisy, noisy_history, lpcs, gains, ltps, periods, offsets):
|
||||
|
||||
periods = periods.copy()
|
||||
|
||||
if pitch_hangover > 0:
|
||||
periods = hangover(periods, num_frames=pitch_hangover)
|
||||
|
||||
periods[periods == 0] = no_pitch_value
|
||||
|
||||
clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)
|
||||
|
||||
if noisy_apply_dct:
|
||||
noisy_cepstrum = np.repeat(
|
||||
cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
|
||||
else:
|
||||
noisy_cepstrum = np.repeat(
|
||||
log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
|
||||
|
||||
log_gains = np.log(gains + 1e-9).reshape(-1, 1)
|
||||
|
||||
acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)
|
||||
|
||||
if add_offset:
|
||||
features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains, offsets.reshape(-1, 1)), axis=-1, dtype=np.float32)
|
||||
else:
|
||||
features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)
|
||||
|
||||
return features, periods.astype(np.int64)
|
||||
|
||||
return create_features
|
||||
|
||||
|
||||
|
||||
def load_inference_data(path,
|
||||
no_pitch_value=256,
|
||||
skip=92,
|
||||
preemph=0.85,
|
||||
acorr_radius=2,
|
||||
pitch_hangover=8,
|
||||
num_bands_clean_spec=64,
|
||||
num_bands_noisy_spec=18,
|
||||
noisy_spec_scale='opus',
|
||||
noisy_apply_dct=True,
|
||||
add_offset=False,
|
||||
add_double_lag_acorr=False,
|
||||
**kwargs):
|
||||
|
||||
print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")
|
||||
|
||||
lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
|
||||
ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
|
||||
gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
|
||||
periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
|
||||
num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
|
||||
num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)
|
||||
offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
|
||||
|
||||
# load signal, add back delay and pre-emphasize
|
||||
signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
|
||||
signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)
|
||||
|
||||
create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset, add_double_lag_acorr)
|
||||
|
||||
num_frames = min((len(signal) // 320) * 4, len(lpcs))
|
||||
signal = signal[: num_frames * 80]
|
||||
lpcs = lpcs[: num_frames]
|
||||
ltps = ltps[: num_frames]
|
||||
gains = gains[: num_frames]
|
||||
periods = periods[: num_frames]
|
||||
num_bits = num_bits[: num_frames // 4]
|
||||
num_bits_smooth = num_bits[: num_frames // 4]
|
||||
offsets = offsets[: num_frames]
|
||||
|
||||
numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)
|
||||
|
||||
features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods, offsets)
|
||||
|
||||
if preemph > 0:
|
||||
signal[1:] -= preemph * signal[:-1]
|
||||
|
||||
return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)
|
194
dnn/torch/osce/utils/spec.py
Normal file
194
dnn/torch/osce/utils/spec.py
Normal file
|
@ -0,0 +1,194 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
def erb(f):
|
||||
return 24.7 * (4.37 * f + 1)
|
||||
|
||||
def inv_erb(e):
|
||||
return (e / 24.7 - 1) / 4.37
|
||||
|
||||
def bark(f):
|
||||
return 6 * m.asinh(f/600)
|
||||
|
||||
def inv_bark(b):
|
||||
return 600 * m.sinh(b / 6)
|
||||
|
||||
|
||||
scale_dict = {
|
||||
'bark': [bark, inv_bark],
|
||||
'erb': [erb, inv_erb]
|
||||
}
|
||||
|
||||
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
|
||||
|
||||
f0 = 0
|
||||
num_bins = n_fft // 2 + 1
|
||||
f1 = fs / n_fft * (num_bins - 1)
|
||||
fstep = fs / n_fft
|
||||
|
||||
if scale == 'opus':
|
||||
bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
|
||||
fac = 1000 * n_fft / fs / 5
|
||||
if num_bands != 18:
|
||||
print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
|
||||
num_bands = 18
|
||||
center_bins = np.array([fac * bin for bin in bins_5ms])
|
||||
else:
|
||||
to_scale, from_scale = scale_dict[scale]
|
||||
|
||||
s0 = to_scale(f0)
|
||||
s1 = to_scale(f1)
|
||||
|
||||
center_freqs = np.array([f0] + [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)] + [f1])
|
||||
center_bins = (center_freqs - f0) / fstep
|
||||
|
||||
if round_center_bins:
|
||||
center_bins = np.round(center_bins)
|
||||
|
||||
filter_bank = np.zeros((num_bands, num_bins))
|
||||
|
||||
band = 0
|
||||
for bin in range(num_bins):
|
||||
# update band index
|
||||
if bin > center_bins[band + 1]:
|
||||
band += 1
|
||||
|
||||
# calculate filter coefficients
|
||||
frac = (center_bins[band + 1] - bin) / (center_bins[band + 1] - center_bins[band])
|
||||
filter_bank[band][bin] = frac
|
||||
filter_bank[band + 1][bin] = 1 - frac
|
||||
|
||||
if return_upper:
|
||||
extend = n_fft - num_bins
|
||||
filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)
|
||||
|
||||
if normalize:
|
||||
filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)
|
||||
|
||||
return filter_bank
|
||||
|
||||
|
||||
def compressed_log_spec(pspec):
|
||||
|
||||
lpspec = np.zeros_like(pspec)
|
||||
num_bands = pspec.shape[-1]
|
||||
|
||||
log_max = -2
|
||||
follow = -2
|
||||
|
||||
for i in range(num_bands):
|
||||
tmp = np.log10(pspec[i] + 1e-9)
|
||||
tmp = max(log_max, max(follow - 2.5, tmp))
|
||||
lpspec[i] = tmp
|
||||
log_max = max(log_max, tmp)
|
||||
follow = max(follow - 2.5, tmp)
|
||||
|
||||
return lpspec
|
||||
|
||||
def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
|
||||
""" calculates cepstrum from SILK lpcs """
|
||||
order = a.shape[-1]
|
||||
assert order + 1 < n_fft
|
||||
|
||||
a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)
|
||||
|
||||
x = np.zeros((*a.shape[:-1], n_fft ))
|
||||
x[..., 0] = 1
|
||||
x[..., 1:1 + order] = -a
|
||||
|
||||
X = np.fft.fft(x, axis=-1)
|
||||
X = np.abs(X[..., :n_fft//2 + 1]) ** power
|
||||
|
||||
S = 1 / (X + eps)
|
||||
|
||||
if fb is None:
|
||||
Sf = S
|
||||
else:
|
||||
Sf = np.matmul(S, fb.T)
|
||||
|
||||
if compress:
|
||||
Sf = np.apply_along_axis(compressed_log_spec, -1, Sf)
|
||||
else:
|
||||
Sf = np.log(Sf + eps)
|
||||
|
||||
return Sf
|
||||
|
||||
def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
|
||||
""" calculates cepstrum from SILK lpcs """
|
||||
|
||||
Sf = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
|
||||
|
||||
cepstrum = scipy.fftpack.dct(Sf, 2, norm='ortho')
|
||||
|
||||
return cepstrum
|
||||
|
||||
|
||||
|
||||
def log_spectrum(x, frame_size, fb=None, window=None, power=1):
|
||||
""" calculate cepstrum on 50% overlapping frames """
|
||||
|
||||
assert(2*len(x)) % frame_size == 0
|
||||
assert frame_size % 2 == 0
|
||||
|
||||
n = len(x)
|
||||
num_even = n // frame_size
|
||||
num_odd = (n - frame_size // 2) // frame_size
|
||||
num_bins = frame_size // 2 + 1
|
||||
|
||||
x_even = x[:num_even * frame_size].reshape(-1, frame_size)
|
||||
x_odd = x[frame_size//2 : frame_size//2 + frame_size * num_odd].reshape(-1, frame_size)
|
||||
|
||||
x_unfold = np.empty((x_even.size + x_odd.size), dtype=x.dtype).reshape((-1, frame_size))
|
||||
x_unfold[::2, :] = x_even
|
||||
x_unfold[1::2, :] = x_odd
|
||||
|
||||
if window is not None:
|
||||
x_unfold *= window.reshape(1, -1)
|
||||
|
||||
X = np.abs(np.fft.fft(x_unfold, n=frame_size, axis=-1))[:, :num_bins] ** power
|
||||
|
||||
if fb is not None:
|
||||
X = np.matmul(X, fb.T)
|
||||
|
||||
|
||||
return np.log(X + 1e-9)
|
||||
|
||||
|
||||
def cepstrum(x, frame_size, fb=None, window=None):
|
||||
""" calculate cepstrum on 50% overlapping frames """
|
||||
|
||||
X = log_spectrum(x, frame_size, fb, window)
|
||||
|
||||
cepstrum = scipy.fftpack.dct(X, 2, norm='ortho')
|
||||
|
||||
return cepstrum
|
92
dnn/torch/osce/utils/templates.py
Normal file
92
dnn/torch/osce/utils/templates.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
setup_dict = dict()
|
||||
|
||||
lace_setup = {
|
||||
'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
|
||||
'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
|
||||
'model': {
|
||||
'name': 'lace',
|
||||
'args': [],
|
||||
'kwargs': {
|
||||
'comb_gain_limit_db': 10,
|
||||
'cond_dim': 128,
|
||||
'conv_gain_limits_db': [-12, 12],
|
||||
'global_gain_limits_db': [-6, 6],
|
||||
'hidden_feature_dim': 96,
|
||||
'kernel_size': 15,
|
||||
'num_features': 93,
|
||||
'numbits_embedding_dim': 8,
|
||||
'numbits_range': [50, 650],
|
||||
'partial_lookahead': True,
|
||||
'pitch_embedding_dim': 64,
|
||||
'pitch_max': 300,
|
||||
'preemph': 0.85,
|
||||
'skip': 91
|
||||
}
|
||||
},
|
||||
'data': {
|
||||
'frames_per_sample': 100,
|
||||
'no_pitch_value': 7,
|
||||
'preemph': 0.85,
|
||||
'skip': 91,
|
||||
'pitch_hangover': 8,
|
||||
'acorr_radius': 2,
|
||||
'num_bands_clean_spec': 64,
|
||||
'num_bands_noisy_spec': 18,
|
||||
'noisy_spec_scale': 'opus',
|
||||
'pitch_hangover': 8,
|
||||
},
|
||||
'training': {
|
||||
'batch_size': 256,
|
||||
'lr': 5.e-4,
|
||||
'lr_decay_factor': 2.5e-5,
|
||||
'epochs': 50,
|
||||
'frames_per_sample': 50,
|
||||
'loss': {
|
||||
'w_l1': 0,
|
||||
'w_lm': 0,
|
||||
'w_logmel': 0,
|
||||
'w_sc': 0,
|
||||
'w_wsc': 0,
|
||||
'w_xcorr': 0,
|
||||
'w_sxcorr': 1,
|
||||
'w_l2': 10,
|
||||
'w_slm': 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
setup_dict = {
|
||||
'lace': lace_setup,
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue