diff --git a/dnn/torch/osce/README.md b/dnn/torch/osce/README.md new file mode 100644 index 00000000..1f940113 --- /dev/null +++ b/dnn/torch/osce/README.md @@ -0,0 +1,4 @@ +# Opus Speech Coding Enhancement + +This folder hosts models for enhancing SILK-coded speech. See the related Opus development branch https://gitlab.xiph.org/xiph/opus/-/tree/exp-neural-silk-enhancement +for feature generation. \ No newline at end of file diff --git a/dnn/torch/osce/data/__init__.py b/dnn/torch/osce/data/__init__.py new file mode 100644 index 00000000..9f7ea183 --- /dev/null +++ b/dnn/torch/osce/data/__init__.py @@ -0,0 +1,30 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +from .silk_enhancement_set import SilkEnhancementSet \ No newline at end of file diff --git a/dnn/torch/osce/data/silk_enhancement_set.py b/dnn/torch/osce/data/silk_enhancement_set.py new file mode 100644 index 00000000..186333e9 --- /dev/null +++ b/dnn/torch/osce/data/silk_enhancement_set.py @@ -0,0 +1,140 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os + +from torch.utils.data import Dataset +import numpy as np + +from utils.silk_features import silk_feature_factory +from utils.pitch import hangover, calculate_acorr_window + + +class SilkEnhancementSet(Dataset): + def __init__(self, + path, + frames_per_sample=100, + no_pitch_value=256, + preemph=0.85, + skip=91, + acorr_radius=2, + pitch_hangover=8, + num_bands_clean_spec=64, + num_bands_noisy_spec=18, + noisy_spec_scale='opus', + noisy_apply_dct=True, + add_offset=False, + add_double_lag_acorr=False + ): + + assert frames_per_sample % 4 == 0 + + self.frame_size = 80 + self.frames_per_sample = frames_per_sample + self.no_pitch_value = no_pitch_value + self.preemph = preemph + self.skip = skip + self.acorr_radius = acorr_radius + self.pitch_hangover = pitch_hangover + self.num_bands_clean_spec = num_bands_clean_spec + self.num_bands_noisy_spec = num_bands_noisy_spec + self.noisy_spec_scale = noisy_spec_scale + self.add_double_lag_acorr = add_double_lag_acorr + + self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16) + self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5) + self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16) + self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32) + self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32) + self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32) + self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32) + + self.clean_signal = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16) + self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16) + + self.create_features = silk_feature_factory(no_pitch_value, + acorr_radius, + pitch_hangover, + num_bands_clean_spec, + num_bands_noisy_spec, + noisy_spec_scale, + noisy_apply_dct, + add_offset, + add_double_lag_acorr) + + self.history_len = 700 if add_double_lag_acorr else 350 + # discard some frames to have enough signal history + self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2) + + num_frames = self.clean_signal.shape[0] // 80 - self.skip_frames + + self.len = num_frames // frames_per_sample + + def __len__(self): + return self.len + + def __getitem__(self, index): + + frame_start = self.frames_per_sample * index + self.skip_frames + frame_stop = frame_start + self.frames_per_sample + + signal_start = frame_start * self.frame_size - self.skip + signal_stop = frame_stop * self.frame_size - self.skip + + clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15 + coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15 + + coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15 + + features, periods = self.create_features( + 
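+            # derive per-frame features and pitch periods from the coded signal, its history, and the SILK parameters below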
coded_signal, + coded_signal_history, + self.lpcs[frame_start : frame_stop], + self.gains[frame_start : frame_stop], + self.ltps[frame_start : frame_stop], + self.periods[frame_start : frame_stop], + self.offsets[frame_start : frame_stop] + ) + + if self.preemph > 0: + clean_signal[1:] -= self.preemph * clean_signal[: -1] + coded_signal[1:] -= self.preemph * coded_signal[: -1] + + num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1) + num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1) + + numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1) + + return { + 'features' : features, + 'periods' : periods.astype(np.int64), + 'target' : clean_signal.astype(np.float32), + 'signals' : coded_signal.reshape(-1, 1).astype(np.float32), + 'numbits' : numbits.astype(np.float32) + } diff --git a/dnn/torch/osce/engine/engine.py b/dnn/torch/osce/engine/engine.py new file mode 100644 index 00000000..7688e9b4 --- /dev/null +++ b/dnn/torch/osce/engine/engine.py @@ -0,0 +1,101 @@ +import torch +from tqdm import tqdm +import sys + +def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10): + + model.to(device) + model.train() + + running_loss = 0 + previous_running_loss = 0 + + + with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch: + + for i, batch in enumerate(tepoch): + + # set gradients to zero + optimizer.zero_grad() + + + # push batch to device + for key in batch: + batch[key] = batch[key].to(device) + + target = batch['target'] + + # calculate model output + output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits']) + + # calculate loss + if isinstance(output, list): + loss = torch.zeros(1, device=device) + for y in output: + loss = loss + criterion(target, y.squeeze(1)) + loss = loss / len(output) + else: + loss = criterion(target, output.squeeze(1)) + + # calculate gradients + loss.backward() + + # update weights + optimizer.step() + + # update learning rate + scheduler.step() + + # update running loss + running_loss += float(loss.cpu()) + + # update status bar + if i % log_interval == 0: + tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}") + previous_running_loss = running_loss + + + running_loss /= len(dataloader) + + return running_loss + +def evaluate(model, criterion, dataloader, device, log_interval=10): + + model.to(device) + model.eval() + + running_loss = 0 + previous_running_loss = 0 + + + with torch.no_grad(): + with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch: + + for i, batch in enumerate(tepoch): + + + + # push batch to device + for key in batch: + batch[key] = batch[key].to(device) + + target = batch['target'] + + # calculate model output + output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits']) + + # calculate loss + loss = criterion(target, output.squeeze(1)) + + # update running loss + running_loss += float(loss.cpu()) + + # update status bar + if i % log_interval == 0: + tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}") + previous_running_loss = running_loss + + + running_loss /= len(dataloader) + + return running_loss \ No newline at end of file diff --git a/dnn/torch/osce/losses/stft_loss.py 
b/dnn/torch/osce/losses/stft_loss.py new file mode 100644 index 00000000..4c164cb6 --- /dev/null +++ b/dnn/torch/osce/losses/stft_loss.py @@ -0,0 +1,277 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +"""STFT-based Loss modules.""" + +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np +import torchaudio + + +def get_window(win_name, win_length, *args, **kwargs): + window_dict = { + 'bartlett_window' : torch.bartlett_window, + 'blackman_window' : torch.blackman_window, + 'hamming_window' : torch.hamming_window, + 'hann_window' : torch.hann_window, + 'kaiser_window' : torch.kaiser_window + } + + if win_name not in window_dict: + raise ValueError(f'unknown window: {win_name}') + + return window_dict[win_name](win_length, *args, **kwargs) + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + Returns: + Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames). 
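+        Note: magnitude values are clamped below at 1e-7 before being returned.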
+ """ + + win = get_window(window, win_length).to(x.device) + x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True) + + + return torch.clamp(torch.abs(x_stft), min=1e-7) + +def spectral_convergence_loss(Y_true, Y_pred): + dims=list(range(1, len(Y_pred.shape))) + return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6)) + + +def log_magnitude_loss(Y_true, Y_pred): + Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15) + Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15) + + return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs)) + +def spectral_xcorr_loss(Y_true, Y_pred): + Y_true = Y_true.abs() + Y_pred = Y_pred.abs() + dims=list(range(1, len(Y_pred.shape))) + xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9) + + return 1 - xcorr.mean() + + + +class MRLogMelLoss(nn.Module): + def __init__(self, + fft_sizes=[512, 256, 128, 64], + overlap=0.5, + fs=16000, + n_mels=18 + ): + + self.fft_sizes = fft_sizes + self.overlap = overlap + self.fs = fs + self.n_mels = n_mels + + super().__init__() + + self.mel_specs = [] + for fft_size in fft_sizes: + hop_size = int(round(fft_size * (1 - self.overlap))) + + n_mels = self.n_mels + if fft_size < 128: + n_mels //= 2 + + self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels)) + + for i, mel_spec in enumerate(self.mel_specs): + self.add_module(f'mel_spec_{i+1}', mel_spec) + + def forward(self, y_true, y_pred): + + loss = torch.zeros(1, device=y_true.device) + + for mel_spec in self.mel_specs: + Y_true = mel_spec(y_true) + Y_pred = mel_spec(y_pred) + loss = loss + log_magnitude_loss(Y_true, Y_pred) + + loss = loss / len(self.mel_specs) + + return loss + +def create_weight_matrix(num_bins, bins_per_band=10): + m = torch.zeros((num_bins, num_bins), dtype=torch.float32) + + r0 = bins_per_band // 2 + r1 = bins_per_band - r0 + + for i in range(num_bins): + i0 = max(i - r0, 0) + j0 = min(i + r1, num_bins) + + m[i, i0: j0] += 1 + + if i < r0: + m[i, :r0 - i] += 1 + + if i > num_bins - r1: + m[i, num_bins - r1 - i:] += 1 + + return m / bins_per_band + +def weighted_spectral_convergence(Y_true, Y_pred, w): + + # calculate sfm based weights + logY = torch.log(torch.abs(Y_true) + 1e-9) + Y = torch.abs(Y_true) + + avg_logY = torch.matmul(logY.transpose(1, 2), w) + avg_Y = torch.matmul(Y.transpose(1, 2), w) + + sfm = torch.exp(avg_logY) / (avg_Y + 1e-9) + + weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2) + + loss = torch.mean( + torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2]) + / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9) + ) + + return loss + +def gen_filterbank(N, Fs=16000): + in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:] + out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None] + #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73. + ERB_N = 24.7 + .108*in_freq + delta = np.abs(in_freq-out_freq)/ERB_N + center = (delta<.5).astype('float32') + R = -12*center*delta**2 + (1-center)*(3-12*delta) + RE = 10.**(R/10.) 
+ norm = np.sum(RE, axis=1) + RE = RE/norm[:, np.newaxis] + return torch.from_numpy(RE) + +def smooth_log_mag(Y_true, Y_pred, filterbank): + Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true)) + Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred)) + + loss = torch.abs( + torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9) + ) + + loss = loss.mean() + + return loss + +class MRSTFTLoss(nn.Module): + def __init__(self, + fft_sizes=[2048, 1024, 512, 256, 128, 64], + overlap=0.5, + window='hann_window', + fs=16000, + log_mag_weight=1, + sc_weight=0, + wsc_weight=0, + smooth_log_mag_weight=0, + sxcorr_weight=0): + super().__init__() + + self.fft_sizes = fft_sizes + self.overlap = overlap + self.window = window + self.log_mag_weight = log_mag_weight + self.sc_weight = sc_weight + self.wsc_weight = wsc_weight + self.smooth_log_mag_weight = smooth_log_mag_weight + self.sxcorr_weight = sxcorr_weight + self.fs = fs + + # weights for SFM weighted spectral convergence loss + self.wsc_weights = torch.nn.ParameterDict() + for fft_size in fft_sizes: + width = min(11, int(1000 * fft_size / self.fs + .5)) + width += width % 2 + self.wsc_weights[str(fft_size)] = torch.nn.Parameter( + create_weight_matrix(fft_size // 2 + 1, width), + requires_grad=False + ) + + # filterbanks for smooth log magnitude loss + self.filterbanks = torch.nn.ParameterDict() + for fft_size in fft_sizes: + self.filterbanks[str(fft_size)] = torch.nn.Parameter( + gen_filterbank(fft_size//2), + requires_grad=False + ) + + + def __call__(self, y_true, y_pred): + + + lm_loss = torch.zeros(1, device=y_true.device) + sc_loss = torch.zeros(1, device=y_true.device) + wsc_loss = torch.zeros(1, device=y_true.device) + slm_loss = torch.zeros(1, device=y_true.device) + sxcorr_loss = torch.zeros(1, device=y_true.device) + + for fft_size in self.fft_sizes: + hop_size = int(round(fft_size * (1 - self.overlap))) + win_size = fft_size + + Y_true = stft(y_true, fft_size, hop_size, win_size, self.window) + Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window) + + if self.log_mag_weight > 0: + lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred) + + if self.sc_weight > 0: + sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred) + + if self.wsc_weight > 0: + wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)]) + + if self.smooth_log_mag_weight > 0: + slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)]) + + if self.sxcorr_weight > 0: + sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred) + + + total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss + + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss + + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes) + + return total_loss \ No newline at end of file diff --git a/dnn/torch/osce/make_default_setup.py b/dnn/torch/osce/make_default_setup.py new file mode 100644 index 00000000..06add8fa --- /dev/null +++ b/dnn/torch/osce/make_default_setup.py @@ -0,0 +1,56 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import argparse + +import yaml + +from utils.templates import setup_dict + +parser = argparse.ArgumentParser() + +parser.add_argument('name', type=str, help='name of default setup file') +parser.add_argument('--model', choices=['lace'], help='model name', default='lace') +parser.add_argument('--path2dataset', type=str, help='dataset path', default=None) + +args = parser.parse_args() + +setup = setup_dict[args.model] + +# update dataset if given +if args.path2dataset is not None: + setup['dataset'] = args.path2dataset + +name = args.name +if not name.endswith('.yml'): + name += '.yml' + +if __name__ == '__main__': + with open(name, 'w') as f: + f.write(yaml.dump(setup)) \ No newline at end of file diff --git a/dnn/torch/osce/models/__init__.py b/dnn/torch/osce/models/__init__.py new file mode 100644 index 00000000..c8dfc5d9 --- /dev/null +++ b/dnn/torch/osce/models/__init__.py @@ -0,0 +1,36 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +from .lace import LACE + + + +model_dict = { + 'lace': LACE +} diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py new file mode 100644 index 00000000..a11dfc41 --- /dev/null +++ b/dnn/torch/osce/models/lace.py @@ -0,0 +1,176 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import torch +from torch import nn +import torch.nn.functional as F + +import numpy as np + +from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d +from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d + +from models.nns_base import NNSBase +from models.silk_feature_net_pl import SilkFeatureNetPL +from models.silk_feature_net import SilkFeatureNet +from .scale_embedding import ScaleEmbedding + +class LACE(NNSBase): + """ Linear-Adaptive Coding Enhancer """ + FRAME_SIZE=80 + + def __init__(self, + num_features=47, + pitch_embedding_dim=64, + cond_dim=256, + pitch_max=257, + kernel_size=15, + preemph=0.85, + skip=91, + comb_gain_limit_db=-6, + global_gain_limits_db=[-6, 6], + conv_gain_limits_db=[-6, 6], + numbits_range=[50, 650], + numbits_embedding_dim=8, + hidden_feature_dim=64, + partial_lookahead=True, + norm_p=2): + + super().__init__(skip=skip, preemph=preemph) + + + self.num_features = num_features + self.cond_dim = cond_dim + self.pitch_max = pitch_max + self.pitch_embedding_dim = pitch_embedding_dim + self.kernel_size = kernel_size + self.preemph = preemph + self.skip = skip + self.numbits_range = numbits_range + self.numbits_embedding_dim = numbits_embedding_dim + self.hidden_feature_dim = hidden_feature_dim + self.partial_lookahead = partial_lookahead + + # pitch embedding + self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim) + + # numbits embedding + self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True) + + # feature net + if partial_lookahead: + self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim) + else: + self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim) + + # comb filters + left_pad = self.kernel_size // 2 + right_pad = self.kernel_size - 1 - left_pad + self.cf1 = 
LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) + self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) + + # spectral shaping + self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + + def flop_count(self, rate=16000, verbose=False): + + frame_rate = rate / self.FRAME_SIZE + + # feature net + feature_net_flops = self.feature_net.flop_count(frame_rate) + comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate) + af_flops = self.af1.flop_count(rate) + + if verbose: + print(f"feature net: {feature_net_flops / 1e6} MFLOPS") + print(f"comb filters: {comb_flops / 1e6} MFLOPS") + print(f"adaptive conv: {af_flops / 1e6} MFLOPS") + + return feature_net_flops + comb_flops + af_flops + + def forward(self, x, features, periods, numbits, debug=False): + + periods = periods.squeeze(-1) + pitch_embedding = self.pitch_embedding(periods) + numbits_embedding = self.numbits_embedding(numbits).flatten(2) + + full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1) + cf = self.feature_net(full_features) + + y = self.cf1(x, cf, periods, debug=debug) + + y = self.cf2(y, cf, periods, debug=debug) + + y = self.af1(y, cf, debug=debug) + + return y + + def get_impulse_responses(self, features, periods, numbits): + """ generates impulse responses on frame centers (input without batch dimension) """ + + num_frames = features.size(0) + batch_size = 32 + max_len = 2 * (self.pitch_max + self.kernel_size) + 10 + + # spread out some pulses + x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE)) + for b in range(batch_size): + x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1 + + # prepare input + x = torch.from_numpy(x).float().to(features.device) + features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0) + periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0) + numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0) + + # run network + with torch.no_grad(): + periods = periods.squeeze(-1) + pitch_embedding = self.pitch_embedding(periods) + numbits_embedding = self.numbits_embedding(numbits).flatten(2) + full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1) + cf = self.feature_net(full_features) + y = self.cf1(x, cf, periods, debug=False) + y = self.cf2(y, cf, periods, debug=False) + y = self.af1(y, cf, debug=False) + + # collect responses + y = y.detach().squeeze().cpu().numpy() + cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE + num_responses = num_frames - cut_frames + responses = np.zeros((num_responses, max_len)) + + for i in range(num_responses): + b = i % batch_size + start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE + stop = start + max_len + + responses[i, :] = y[b, start:stop] + + return responses diff --git a/dnn/torch/osce/models/nns_base.py b/dnn/torch/osce/models/nns_base.py new file mode 100644 index 00000000..6e667b96 --- /dev/null +++ b/dnn/torch/osce/models/nns_base.py @@ 
-0,0 +1,69 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import torch +from torch import nn +import torch.nn.functional as F + +class NNSBase(nn.Module): + + def __init__(self, skip=91, preemph=0.85): + super().__init__() + + self.skip = skip + self.preemph = preemph + + def process(self, sig, features, periods, numbits, debug=False): + + self.eval() + has_numbits = 'numbits' in self.forward.__code__.co_varnames + device = next(iter(self.parameters())).device + with torch.no_grad(): + + # run model + x = sig.view(1, 1, -1).to(device) + f = features.unsqueeze(0).to(device) + p = periods.unsqueeze(0).to(device) + n = numbits.unsqueeze(0).to(device) + + if has_numbits: + y = self.forward(x, f, p, n, debug=debug).squeeze() + else: + y = self.forward(x, f, p, debug=debug).squeeze() + + # deemphasis + if self.preemph > 0: + for i in range(len(y) - 1): + y[i + 1] += self.preemph * y[i] + + # delay compensation + y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device))) + out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short() + + return out \ No newline at end of file diff --git a/dnn/torch/osce/models/scale_embedding.py b/dnn/torch/osce/models/scale_embedding.py new file mode 100644 index 00000000..58695302 --- /dev/null +++ b/dnn/torch/osce/models/scale_embedding.py @@ -0,0 +1,68 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import math as m +import torch +from torch import nn + + +class ScaleEmbedding(nn.Module): + def __init__(self, + dim, + min_val, + max_val, + logscale=False): + + super().__init__() + + if min_val >= max_val: + raise ValueError('min_val must be smaller than max_val') + + if min_val <= 0 and logscale: + raise ValueError('min_val must be positive when logscale is true') + + self.dim = dim + self.logscale = logscale + self.min_val = min_val + self.max_val = max_val + + if logscale: + self.min_val = m.log(self.min_val) + self.max_val = m.log(self.max_val) + + + self.offset = (self.min_val + self.max_val) / 2 + self.scale_factors = nn.Parameter( + torch.arange(1, dim+1, dtype=torch.float32) * torch.pi / (self.max_val - self.min_val) + ) + + def forward(self, x): + if self.logscale: x = torch.log(x) + x = torch.clip(x, self.min_val, self.max_val) - self.offset + return torch.sin(x.unsqueeze(-1) * self.scale_factors - 0.5) diff --git a/dnn/torch/osce/models/silk_feature_net.py b/dnn/torch/osce/models/silk_feature_net.py new file mode 100644 index 00000000..ed22f52e --- /dev/null +++ b/dnn/torch/osce/models/silk_feature_net.py @@ -0,0 +1,86 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + + +import torch +from torch import nn +import torch.nn.functional as F + +from utils.complexity import _conv1d_flop_count + +class SilkFeatureNet(nn.Module): + + def __init__(self, + feature_dim=47, + num_channels=256, + lookahead=False): + + super(SilkFeatureNet, self).__init__() + + self.feature_dim = feature_dim + self.num_channels = num_channels + self.lookahead = lookahead + + self.conv1 = nn.Conv1d(feature_dim, num_channels, 3) + self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2) + + self.gru = nn.GRU(num_channels, num_channels, batch_first=True) + + def flop_count(self, rate=200): + count = 0 + for conv in self.conv1, self.conv2: + count += _conv1d_flop_count(conv, rate) + + count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate + + return count + + + def forward(self, features, state=None): + """ features shape: (batch_size, num_frames, feature_dim) """ + + batch_size = features.size(0) + + if state is None: + state = torch.zeros((1, batch_size, self.num_channels), device=features.device) + + + features = features.permute(0, 2, 1) + if self.lookahead: + c = torch.tanh(self.conv1(F.pad(features, [1, 1]))) + c = torch.tanh(self.conv2(F.pad(c, [2, 2]))) + else: + c = torch.tanh(self.conv1(F.pad(features, [2, 0]))) + c = torch.tanh(self.conv2(F.pad(c, [4, 0]))) + + c = c.permute(0, 2, 1) + + c, _ = self.gru(c, state) + + return c \ No newline at end of file diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py new file mode 100644 index 00000000..ae37951c --- /dev/null +++ b/dnn/torch/osce/models/silk_feature_net_pl.py @@ -0,0 +1,90 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + + +import torch +from torch import nn +import torch.nn.functional as F + +from utils.complexity import _conv1d_flop_count + +class SilkFeatureNetPL(nn.Module): + """ feature net with partial lookahead """ + def __init__(self, + feature_dim=47, + num_channels=256, + hidden_feature_dim=64): + + super(SilkFeatureNetPL, self).__init__() + + self.feature_dim = feature_dim + self.num_channels = num_channels + self.hidden_feature_dim = hidden_feature_dim + + self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1) + self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2) + self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4) + + self.gru = nn.GRU(num_channels, num_channels, batch_first=True) + + def flop_count(self, rate=200): + count = 0 + for conv in self.conv1, self.conv2, self.tconv: + count += _conv1d_flop_count(conv, rate) + + count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate + + return count + + + def forward(self, features, state=None): + """ features shape: (batch_size, num_frames, feature_dim) """ + + batch_size = features.size(0) + num_frames = features.size(1) + + if state is None: + state = torch.zeros((1, batch_size, self.num_channels), device=features.device) + + features = features.permute(0, 2, 1) + # dimensionality reduction + c = torch.tanh(self.conv1(features)) + + # frame accumulation + c = c.permute(0, 2, 1) + c = c.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1) + c = torch.tanh(self.conv2(F.pad(c, [1, 0]))) + + # upsampling + c = self.tconv(c) + c = c.permute(0, 2, 1) + + c, _ = self.gru(c, state) + + return c \ No newline at end of file diff --git a/dnn/torch/osce/test_model.py b/dnn/torch/osce/test_model.py new file mode 100644 index 00000000..616a0ec5 --- /dev/null +++ b/dnn/torch/osce/test_model.py @@ -0,0 +1,96 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +import argparse + +import torch + +from scipy.io import wavfile + + +from models import model_dict +from utils.silk_features import load_inference_data +from utils import endoscopy + +debug = False +if debug: + args = type('dummy', (object,), + { + 'input' : 'testitems/all_0_orig.se', + 'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth', + 'output' : 'out.wav', + 'debug' : False, + })() +else: + parser = argparse.ArgumentParser() + + parser.add_argument('input', type=str, help='path to folder with features and signals') + parser.add_argument('checkpoint', type=str, help='checkpoint file') + parser.add_argument('output', type=str, help='output file') + parser.add_argument('--debug', action='store_true', help='enables debug output') + + + args = parser.parse_args() + + +torch.set_num_threads(2) + +input_folder = args.input +checkpoint_file = args.checkpoint + + +output_file = args.output +if not output_file.endswith('.wav'): + output_file += '.wav' + +checkpoint = torch.load(checkpoint_file, map_location="cpu") + +# check model +if 'name' not in checkpoint['setup']['model']: + print('warning: did not find model name entry in setup, using pitchpostfilter by default') + model_name = 'pitchpostfilter' +else: + model_name = checkpoint['setup']['model']['name'] + +model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs']) + +model.load_state_dict(checkpoint['state_dict']) + +# generate model input +setup = checkpoint['setup'] +signal, features, periods, numbits = load_inference_data(input_folder, **setup['data']) + +if args.debug: + endoscopy.init() + +output = model.process(signal, features, periods, numbits, debug=args.debug) + +wavfile.write(output_file, 16000, output.cpu().numpy()) + +if args.debug: + endoscopy.close() diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py new file mode 100644 index 00000000..6e2514b9 --- /dev/null +++ b/dnn/torch/osce/train_model.py @@ -0,0 +1,297 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +import os +import argparse +import sys + +import yaml + +try: + import git + has_git = True +except ImportError: + has_git = False + +import torch
+from torch.optim.lr_scheduler import LambdaLR + +import numpy as np + +from scipy.io import wavfile + +import pesq + +from data import SilkEnhancementSet +from models import model_dict +from engine.engine import train_one_epoch, evaluate + + +from utils.silk_features import load_inference_data +from utils.misc import count_parameters + +from losses.stft_loss import MRSTFTLoss, MRLogMelLoss + + +parser = argparse.ArgumentParser() + +parser.add_argument('setup', type=str, help='setup yaml file') +parser.add_argument('output', type=str, help='output path') +parser.add_argument('--device', type=str, help='compute device', default=None) +parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None) +parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None) +parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout') + +args = parser.parse_args() + + +torch.set_num_threads(4) + +with open(args.setup, 'r') as f: + setup = yaml.load(f.read(), yaml.FullLoader) + +checkpoint_prefix = 'checkpoint' +output_prefix = 'output' +setup_name = 'setup.yml' +output_file = 'out.txt' + + +# check model +if 'name' not in setup['model']: + print('warning: did not find model entry in setup, using default PitchPostFilter') + model_name = 'pitchpostfilter' +else: + model_name = setup['model']['name'] + +# prepare output folder +if os.path.exists(args.output): + print("warning: output folder exists") + + reply = input('continue? (y/n): ') + while reply not in {'y', 'n'}: + reply = input('continue? (y/n): ') + + if reply == 'n': + sys.exit() +else: + os.makedirs(args.output, exist_ok=True) + +checkpoint_dir = os.path.join(args.output, 'checkpoints') +os.makedirs(checkpoint_dir, exist_ok=True) + +# add repo info to setup +if has_git: + working_dir = os.path.split(__file__)[0] + try: + repo = git.Repo(working_dir) + setup['repo'] = dict() + hash = repo.head.object.hexsha + urls = list(repo.remote().urls) + is_dirty = repo.is_dirty() + + if is_dirty: + print("warning: repo is dirty") + + setup['repo']['hash'] = hash + setup['repo']['urls'] = urls + setup['repo']['dirty'] = is_dirty + except Exception: + has_git = False + +# dump setup +with open(os.path.join(args.output, setup_name), 'w') as f: + yaml.dump(setup, f) + +ref = None +if args.testdata is not None: + + testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data']) + + inference_test = True + inference_folder = os.path.join(args.output, 'inference_test') + os.makedirs(inference_folder, exist_ok=True) + + try: + ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16) + except Exception: + pass +else: + inference_test = False + +# training parameters +batch_size = setup['training']['batch_size'] +epochs = setup['training']['epochs'] +lr = setup['training']['lr'] +lr_decay_factor = setup['training']['lr_decay_factor'] + +# load training dataset +data_config = setup['data'] +data = SilkEnhancementSet(setup['dataset'], **data_config) + +# load validation dataset if given +if 'validation_dataset' in setup: + validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config) + + validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8) + + run_validation = True +else: 
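+    # no validation dataset given in the setup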
run_validation = False + +# create model +model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs']) + +if args.initial_checkpoint is not None: + print(f"loading state dict from {args.initial_checkpoint}...") + chkpt = torch.load(args.initial_checkpoint, map_location='cpu') + model.load_state_dict(chkpt['state_dict']) + +# set compute device +if args.device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +else: + device = torch.device(args.device) + +# push model to device +model.to(device) + +# dataloader +dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8) + +# optimizer over trainable parameters only +parameters = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.Adam(parameters, lr=lr) + +# learning rate scheduler +scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x)) + +# loss +w_l1 = setup['training']['loss']['w_l1'] +w_lm = setup['training']['loss']['w_lm'] +w_slm = setup['training']['loss']['w_slm'] +w_sc = setup['training']['loss']['w_sc'] +w_logmel = setup['training']['loss']['w_logmel'] +w_wsc = setup['training']['loss']['w_wsc'] +w_xcorr = setup['training']['loss']['w_xcorr'] +w_sxcorr = setup['training']['loss']['w_sxcorr'] +w_l2 = setup['training']['loss']['w_l2'] + +w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + +stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device) +logmelloss = MRLogMelLoss().to(device) + +def xcorr_loss(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9) + + return torch.mean(loss) + +def td_l2_norm(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6) + + return loss.mean() + +def td_l1(y_true, y_pred, pow=0): + dims = list(range(1, len(y_true.shape))) + tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow) + + return torch.mean(tmp) + +def criterion(x, y): + + return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y) + + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum + + + +# model checkpoint +checkpoint = { + 'setup' : setup, + 'state_dict' : model.state_dict(), + 'loss' : -1 +} + + + + +if not args.no_redirect: + print(f"re-directing output to {os.path.join(args.output, output_file)}") + sys.stdout = open(os.path.join(args.output, output_file), "w") + +print("summary:") + +print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters") +if hasattr(model, 'flop_count'): + print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS") + +if ref is not None: + noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16) + initial_mos = pesq.pesq(16000, ref, noisy, mode='wb') + print(f"initial MOS (PESQ): {initial_mos}") + +best_loss = 1e9 + +for ep in range(1, epochs + 1): + print(f"training epoch {ep}...") + new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler) + + + # save checkpoint + checkpoint['state_dict'] = model.state_dict() + checkpoint['loss'] = new_loss + + if run_validation: + print("running validation...") 
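+        # evaluate() returns the criterion averaged over the validation set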
validation_loss = evaluate(model, criterion, validation_dataloader, device) + checkpoint['validation_loss'] = validation_loss + + if validation_loss < best_loss: + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth')) + best_loss = validation_loss + + if inference_test: + print("running inference test...") + out = model.process(testsignal, features, periods, numbits).cpu().numpy() + wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out) + if ref is not None: + mos = pesq.pesq(16000, ref, out, mode='wb') + print(f"MOS (PESQ): {mos}") + + + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth')) + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth')) + + + print() + +print('Done') diff --git a/dnn/torch/osce/utils/complexity.py b/dnn/torch/osce/utils/complexity.py new file mode 100644 index 00000000..79de22c5 --- /dev/null +++ b/dnn/torch/osce/utils/complexity.py @@ -0,0 +1,35 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +def _conv1d_flop_count(layer, rate): + return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0] + + +def _dense_flop_count(layer, rate): + return 2 * ((layer.in_features + 1) * layer.out_features * rate ) \ No newline at end of file diff --git a/dnn/torch/osce/utils/endoscopy.py b/dnn/torch/osce/utils/endoscopy.py new file mode 100644 index 00000000..141447e2 --- /dev/null +++ b/dnn/torch/osce/utils/endoscopy.py @@ -0,0 +1,234 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +""" module for inspecting models during inference """ + +import os + +import yaml +import matplotlib.pyplot as plt +import matplotlib.animation as animation + +import torch +import numpy as np + +# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}} +_state = dict() +_folder = 'endoscopy' + +def get_gru_gates(gru, input, state): + hidden_size = gru.hidden_size + + direct = torch.matmul(gru.weight_ih_l0, input.squeeze()) + recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze()) + + # reset gate + start, stop = 0 * hidden_size, 1 * hidden_size + reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop]) + + # update gate + start, stop = 1 * hidden_size, 2 * hidden_size + update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop]) + + # new gate + start, stop = 2 * hidden_size, 3 * hidden_size + new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop])) + + return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate} + + +def init(folder='endoscopy'): + """ sets up output folder for endoscopy data """ + + global _folder + _folder = folder + + if not os.path.exists(folder): + os.makedirs(folder) + else: + print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.") + +def write_data(key, data, fs): + """ appends data to previous data written under key """ + + global _state + + # convert to numpy if torch.Tensor is given + if isinstance(data, torch.Tensor): + data = data.detach().numpy() + + if not key in _state: + _state[key] = { + 'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'), + 'fs' : fs, + 'dim' : tuple(data.shape), + 'dtype' : str(data.dtype) + } + + with open(os.path.join(_folder, key + '.yml'), 'w') as f: + f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]})) + else: + if _state[key]['fs'] != fs: + raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}") + if _state[key]['dtype'] != str(data.dtype): + raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}") + if _state[key]['dim'] != tuple(data.shape): + raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. 
+
+    _state[key]['fid'].write(data.tobytes())
+
+def close(folder='endoscopy'):
+    """ clean up """
+    for key in _state.keys():
+        _state[key]['fid'].close()
+    # reset state so the module can be re-initialized
+    _state.clear()
+
+
+def read_data(folder='endoscopy'):
+    """ retrieves written data as numpy arrays """
+
+    keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
+
+    return_dict = dict()
+
+    for key in keys:
+        with open(os.path.join(folder, key + '.yml'), 'r') as f:
+            value = yaml.load(f.read(), yaml.FullLoader)
+
+        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
+            data = np.frombuffer(f.read(), dtype=value['dtype'])
+
+        value['data'] = data.reshape((-1,) + value['dim'])
+
+        return_dict[key] = value
+
+    return return_dict
+
+def get_best_reshape(shape, target_ratio=1):
+    """ calculates the best 2d reshape of shape given the target ratio (rows/cols) """
+
+    if len(shape) > 1:
+        pixel_count = 1
+        for s in shape:
+            pixel_count *= s
+    else:
+        pixel_count = shape[0]
+
+    if pixel_count == 1:
+        return (1,)
+
+    num_columns = int((pixel_count / target_ratio)**.5)
+
+    while (pixel_count % num_columns):
+        num_columns -= 1
+
+    num_rows = pixel_count // num_columns
+
+    return (num_rows, num_columns)
+
+def get_type_and_shape(shape):
+
+    # can happen if data is one dimensional
+    if len(shape) == 0:
+        shape = (1,)
+
+    # calculate pixel count
+    if len(shape) > 1:
+        pixel_count = 1
+        for s in shape:
+            pixel_count *= s
+    else:
+        pixel_count = shape[0]
+
+    if pixel_count == 1:
+        return 'plot', (1, )
+
+    # stay with shape if already 2-dimensional
+    if len(shape) == 2:
+        if (shape[0] != pixel_count) or (shape[1] != pixel_count):
+            return 'image', shape
+
+    return 'image', get_best_reshape(shape)
+
+def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
+
+    # determine plot setup
+    num_keys = len(data.keys())
+
+    num_rows = max(1, int((num_keys * 3/4) ** .5))
+
+    num_cols = (num_keys + num_rows - 1) // num_rows
+
+    # squeeze=False keeps axs 2-dimensional even for a single row
+    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
+    fig.set_size_inches(num_cols * 5, num_rows * 5)
+
+    display = dict()
+
+    fs_max = max([val['fs'] for val in data.values()])
+
+    num_samples = max([val['data'].shape[0] for val in data.values()])
+
+    keys = sorted(data.keys())
+
+    # inspect data
+    for i, key in enumerate(keys):
+        axs[i // num_cols, i % num_cols].title.set_text(key)
+
+        display[key] = dict()
+
+        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
+        display[key]['down_factor'] = data[key]['fs'] / fs_max
+
+    start_index = max(start_index, half_signal_window_length)
+    while stop_index < 0:
+        stop_index += num_samples
+
+    stop_index = min(stop_index, num_samples - half_signal_window_length)
+
+    # actual plotting
+    frames = []
+    for index in range(start_index, stop_index):
+        ims = []
+        for i, key in enumerate(keys):
+            # map the full-rate sample index to the (possibly slower) rate of this key
+            feature_index = int(round(index * display[key]['down_factor']))
+
+            if display[key]['type'] == 'plot':
+                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
+
+            elif display[key]['type'] == 'image':
+                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][feature_index].reshape(display[key]['shape']), animated=True))
+
+        frames.append(ims)
+
+    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
+
+    if not filename.endswith('.mp4'):
+        filename += '.mp4'
+
+    ani.save(filename)
\ No newline at end of file
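A hedged usage sketch for the endoscopy module (the probe calls normally happen inside a model's forward pass, for example via the `debug` flag of the adaptive layers below; `model`, `x` and `features` are placeholders):

import utils.endoscopy as endoscopy

endoscopy.init('endoscopy')           # create the output folder
y = model(x, features, debug=True)    # layers call endoscopy.write_data(key, tensor, fs)
endoscopy.close()

data = endoscopy.read_data('endoscopy')           # {key : {'fs', 'dim', 'dtype', 'data'}}
endoscopy.make_animation(data, 'inspection.mp4')  # animated side-by-side view of all keys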
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
new file mode 100644
index 00000000..b146240e
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
@@ -0,0 +1,236 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class LimitedAdaptiveComb1d(nn.Module):
+    COUNTER = 1
+
+    def __init__(self,
+                 kernel_size,
+                 feature_dim,
+                 frame_size=160,
+                 overlap_size=40,
+                 use_bias=True,
+                 padding=None,
+                 max_lag=256,
+                 name=None,
+                 gain_limit_db=10,
+                 global_gain_limits_db=[-6, 6],
+                 norm_p=2):
+        """
+
+        Parameters:
+        -----------
+
+        kernel_size : int
+            kernel size of the adaptive comb filter
+
+        feature_dim : int
+            dimension of features from which kernels, biases and gains are computed
+
+        frame_size : int, optional
+            frame size, defaults to 160
+
+        overlap_size : int, optional
+            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
+
+        use_bias : bool, optional
+            if true, biases will be added to output channels. Defaults to True
+
+        padding : List[int, int], optional
+            left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+
+        max_lag : int, optional
+            maximal pitch lag, defaults to 256
+
+        gain_limit_db : float, optional
+            upper limit for the frame-wise comb filter gain in dB, defaults to 10
+
+        global_gain_limits_db : List[float, float], optional
+            lower and upper limit for the global filter gain in dB, defaults to [-6, 6]
+
+        norm_p : int, optional
+            p of the p-norm used for normalizing the filter kernels, defaults to 2
+
+        name: str or None, optional
+            specifies a name attribute for the module. If None, the name is auto-generated as limited_adaptive_comb1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
+
+        """
+
+        super(LimitedAdaptiveComb1d, self).__init__()
+
+        self.in_channels = 1
+        self.out_channels = 1
+        self.feature_dim = feature_dim
+        self.kernel_size = kernel_size
+        self.frame_size = frame_size
+        self.overlap_size = overlap_size
+        self.use_bias = use_bias
+        self.max_lag = max_lag
+        self.limit_db = gain_limit_db
+        self.norm_p = norm_p
+
+        if name is None:
+            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
+            LimitedAdaptiveComb1d.COUNTER += 1
+        else:
+            self.name = name
+
+        # network for generating convolution weights
+        self.conv_kernel = nn.Linear(feature_dim, kernel_size)
+
+        if self.use_bias:
+            self.conv_bias = nn.Linear(feature_dim, 1)
+
+        # comb filter gain
+        self.filter_gain = nn.Linear(feature_dim, 1)
+        self.log_gain_limit = gain_limit_db * 0.11512925464970229  # gain_limit_db * log(10) / 20
+        with torch.no_grad():
+            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
+
+        self.global_filter_gain = nn.Linear(feature_dim, 1)
+        log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
+        self.filter_gain_a = (log_max - log_min) / 2
+        self.filter_gain_b = (log_max + log_min) / 2
+
+        if padding is None:
+            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+        else:
+            self.padding = padding
+
+        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+    def forward(self, x, features, lags, debug=False):
+        """ adaptive comb filtering of a 1d signal
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            input signal of shape (batch_size, in_channels, num_samples)
+
+        features : torch.tensor
+            frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+        lags : torch.LongTensor
+            frame-wise lags for comb-filtering
+
+        """
+
+        batch_size = x.size(0)
+        num_frames = features.size(1)
+        num_samples = x.size(2)
+        frame_size = self.frame_size
+        overlap_size = self.overlap_size
+        kernel_size = self.kernel_size
+        win1 = torch.flip(self.overlap_win, [0])
+        win2 = self.overlap_win
+
+        if num_samples // self.frame_size != num_frames:
+            raise ValueError('non matching sizes in LimitedAdaptiveComb1d.forward')
+
+        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
+
+        if self.use_bias:
+            conv_biases = self.conv_bias(features).permute(0, 2, 1)
+
+        # calculate frame-wise and global gains
+        conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
+        global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+
+        if debug and batch_size == 1:
+            key = self.name + "_gains"
+            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_kernels"
+            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_lags"
+            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_global_conv_gains"
+            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+        # frame-wise convolution with overlap-add
+        output_frames = []
+        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+        x = F.pad(x, self.padding)
+        x = F.pad(x, [self.max_lag, self.overlap_size])
+
+        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
+        idx = torch.repeat_interleave(idx, batch_size, 0)
+        idx = torch.repeat_interleave(idx, self.in_channels, 1)
+
+        for i in range(num_frames):
+
+            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
+            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
+
+            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
+
+            if self.use_bias:
+                new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+
+            offset = self.max_lag + self.padding[0]
+            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])
+
+            # overlapping part
+            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+            # non-overlapping part
+            output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+            # mem for next frame
+            overlap_mem = new_chunk[:, :, frame_size :]
+
+        # concatenate chunks
+        output = torch.cat(output_frames, dim=-1)
+
+        return output
+
+    def flop_count(self, rate):
+        frame_rate = rate / self.frame_size
+        overlap = self.overlap_size
+        overhead = overlap / self.frame_size
+
+        count = 0
+
+        # kernel computation and filtering
+        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+        # bias computation
+        if self.use_bias:
+            count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
+
+        # global gain computation
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+        # windowing
+        count += overlap * frame_rate * 3 * self.out_channels
+
+        return count
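A minimal instantiation sketch for this layer (the shapes are assumptions: 50 frames of 160 samples at 16 kHz with 128-dimensional conditioning features):

import torch
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d

comb = LimitedAdaptiveComb1d(kernel_size=15, feature_dim=128, frame_size=160, max_lag=300)

x = torch.randn(1, 1, 50 * 160)         # (batch_size, in_channels, num_samples)
features = torch.randn(1, 50, 128)      # one feature vector per 160-sample frame
lags = torch.randint(32, 300, (1, 50))  # frame-wise pitch lags in samples
y = comb(x, features, lags)             # filtered signal, same shape as x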
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
new file mode 100644
index 00000000..5992296f
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -0,0 +1,222 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class LimitedAdaptiveConv1d(nn.Module):
+    COUNTER = 1
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 feature_dim,
+                 frame_size=160,
+                 overlap_size=40,
+                 use_bias=True,
+                 padding=None,
+                 name=None,
+                 gain_limits_db=[-6, 6],
+                 shape_gain_db=0,
+                 norm_p=2):
+        """
+
+        Parameters:
+        -----------
+
+        in_channels : int
+            number of input channels
+
+        out_channels : int
+            number of output channels
+
+        kernel_size : int
+            kernel size of the adaptive convolution
+
+        feature_dim : int
+            dimension of features from which kernels, biases and gains are computed
+
+        frame_size : int, optional
+            frame size, defaults to 160
+
+        overlap_size : int, optional
+            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
+
+        use_bias : bool, optional
+            if true, biases will be added to output channels, defaults to True
+
+        padding : List[int, int], optional
+            left and right padding, defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+
+        name : str or None, optional
+            specifies a name attribute for the module. If None, the name is auto-generated as limited_adaptive_conv1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveConv1d
+
+        gain_limits_db : List[float, float], optional
+            lower and upper limit for the filter gains in dB, defaults to [-6, 6]
+
+        shape_gain_db : float, optional
+            gain in dB controlling how far the learned kernel may deviate from an identity (pass-through) kernel, defaults to 0 (no restriction)
+
+        norm_p : int, optional
+            p of the p-norm used for normalizing the filter kernels, defaults to 2
+
+        """
+
+        super(LimitedAdaptiveConv1d, self).__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.feature_dim = feature_dim
+        self.kernel_size = kernel_size
+        self.frame_size = frame_size
+        self.overlap_size = overlap_size
+        self.use_bias = use_bias
+        self.gain_limits_db = gain_limits_db
+        self.shape_gain_db = shape_gain_db
+        self.norm_p = norm_p
+
+        if name is None:
+            self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
+            LimitedAdaptiveConv1d.COUNTER += 1
+        else:
+            self.name = name
+
+        # network for generating convolution weights
+        self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
+
+        if self.use_bias:
+            self.conv_bias = nn.Linear(feature_dim, out_channels)
+
+        self.shape_gain = min(1, 10**(shape_gain_db / 20))
+
+        self.filter_gain = nn.Linear(feature_dim, out_channels)
+        log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229  # db_limit * log(10) / 20
+        self.filter_gain_a = (log_max - log_min) / 2
+        self.filter_gain_b = (log_max + log_min) / 2
+
+        if padding is None:
+            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+        else:
+            self.padding = padding
+
+        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+    def flop_count(self, rate):
+        frame_rate = rate / self.frame_size
+        overlap = self.overlap_size
+        overhead = overlap / self.frame_size
+
+        count = 0
+
+        # kernel computation and filtering
+        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+
+        # bias computation
+        if self.use_bias:
+            count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
+
+        # gain computation
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+        # windowing
+        count += 3 * overlap * frame_rate * self.out_channels
+
+        return count
+
+    def forward(self, x, features, debug=False):
+        """ adaptive 1d convolution
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            input signal of shape (batch_size, in_channels, num_samples)
+
+        features : torch.tensor
+            frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+        """
+
+        batch_size = x.size(0)
+        num_frames = features.size(1)
+        num_samples = x.size(2)
+        frame_size = self.frame_size
+        overlap_size = self.overlap_size
+        kernel_size = self.kernel_size
+        win1 = torch.flip(self.overlap_win, [0])
+        win2 = self.overlap_win
+
+        if num_samples // self.frame_size != num_frames:
+            raise ValueError('non matching sizes in LimitedAdaptiveConv1d.forward')
+
+        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+
+        # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
+        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))
+
+        # limit shape: blend between the learned kernel and an identity kernel
+        id_kernels = torch.zeros_like(conv_kernels)
+        id_kernels[..., self.padding[1]] = 1
+
+        conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels
+
+        if self.use_bias:
+            conv_biases = self.conv_bias(features).permute(0, 2, 1)
+
+        # calculate gains
+        conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+        if debug and batch_size == 1:
+            key = self.name + "_gains"
+            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_kernels"
+            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+        # frame-wise convolution with overlap-add
+        output_frames = []
+        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+        x = F.pad(x, self.padding)
+        x = F.pad(x, [0, self.overlap_size])
+
+        for i in range(num_frames):
+            xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1))
+            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
+
+            if self.use_bias:
+                new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+
+            new_chunk = new_chunk * conv_gains[:, :, i : i + 1]
+
+            # overlapping part
+            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+            # non-overlapping part
+            output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+            # mem for next frame
+            overlap_mem = new_chunk[:, :, frame_size :]
+
+        # concatenate chunks
+        output = torch.cat(output_frames, dim=-1)
+
+        return output
\ No newline at end of file
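And the analogous sketch for the adaptive convolution, here mixing two channels (for example the input signal and a comb-filtered copy) down to one; again all shapes are assumptions:

import torch
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d

mixer = LimitedAdaptiveConv1d(in_channels=2, out_channels=1, kernel_size=15, feature_dim=128)

x = torch.randn(1, 2, 50 * 160)     # (batch_size, in_channels, num_samples)
features = torch.randn(1, 50, 128)  # one feature vector per frame
y = mixer(x, features)              # (1, 1, 50 * 160)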
diff --git a/dnn/torch/osce/utils/layers/pitch_auto_correlator.py b/dnn/torch/osce/utils/layers/pitch_auto_correlator.py
new file mode 100644
index 00000000..ef58ae8e
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/pitch_auto_correlator.py
@@ -0,0 +1,84 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PitchAutoCorrelator(nn.Module):
+    def __init__(self,
+                 frame_size=80,
+                 pitch_min=32,
+                 pitch_max=300,
+                 radius=2):
+
+        super().__init__()
+
+        self.frame_size = frame_size
+        self.pitch_min = pitch_min
+        self.pitch_max = pitch_max
+        self.radius = radius
+
+    def forward(self, x, periods):
+        # x of shape (batch_size, channels, num_samples)
+        # periods of shape (batch_size, num_frames)
+
+        num_frames = periods.size(1)
+        batch_size = periods.size(0)
+        num_samples = self.frame_size * num_frames
+        channels = x.size(1)
+
+        assert num_samples == x.size(-1)
+
+        # lag offsets -radius..radius (named delta to avoid shadowing the builtin range)
+        delta = torch.arange(-self.radius, self.radius + 1, device=x.device)
+        idx = torch.arange(self.frame_size * num_frames, device=x.device)
+        p_up = torch.repeat_interleave(periods, self.frame_size, 1)
+        lookup = idx + self.pitch_max - p_up
+        lookup = lookup.unsqueeze(-1) + delta
+        lookup = lookup.unsqueeze(1)
+
+        # padding
+        x_pad = F.pad(x, [self.pitch_max, 0])
+        x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)
+
+        # framing
+        x_select = torch.gather(x_ext, 2, lookup)
+        x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
+        lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)
+
+        # calculate auto-correlation
+        dotp = torch.sum(x_frames * lag_frames, dim=-2)
+        frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
+        lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)
+
+        acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)
+
+        return acorr
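A brief shape sketch for the correlator (values are assumptions; periods should stay below pitch_max - radius so the lag window fits inside the padded history):

import torch
from utils.layers.pitch_auto_correlator import PitchAutoCorrelator

correlator = PitchAutoCorrelator(frame_size=80, pitch_min=32, pitch_max=300, radius=2)

x = torch.randn(1, 1, 100 * 80)             # (batch_size, channels, num_samples)
periods = torch.randint(32, 298, (1, 100))  # frame-wise pitch periods in samples
acorr = correlator(x, periods)              # (1, 1, 100, 2 * radius + 1)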
diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py
new file mode 100644
index 00000000..d4c03478
--- /dev/null
+++ b/dnn/torch/osce/utils/misc.py
@@ -0,0 +1,42 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+def count_parameters(model, verbose=False):
+    """ returns the total number of parameters of model; prints per-parameter counts if verbose is True """
+    total = 0
+    for name, p in model.named_parameters():
+        count = p.numel()
+
+        if verbose:
+            print(f"{name}: {count} parameters")
+
+        total += count
+
+    return total
\ No newline at end of file
diff --git a/dnn/torch/osce/utils/pitch.py b/dnn/torch/osce/utils/pitch.py
new file mode 100644
index 00000000..32b3bbf8
--- /dev/null
+++ b/dnn/torch/osce/utils/pitch.py
@@ -0,0 +1,121 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import numpy as np
+
+def hangover(lags, num_frames=10):
+    lags = lags.copy()
+    count = 0
+    last_lag = 0
+
+    for i in range(len(lags)):
+        lag = lags[i]
+
+        if lag == 0:
+            if count < num_frames:
+                lags[i] = last_lag
+                count += 1
+        else:
+            # remember the last valid lag and reset the hangover counter
+            last_lag = lag
+            count = 0
+
+    return lags
+
+
+def smooth_pitch_lags(lags, d=2):
+
+    assert d < 4
+
+    num_silk_frames = len(lags) // 4
+
+    smoothed_lags = lags.copy()
+
+    tmp = np.arange(1, d+1)
+    kernel = np.concatenate((tmp, [d+1], tmp[::-1]), dtype=np.float32)
+    kernel = kernel / np.sum(kernel)
+
+    last = lags[0:d][::-1]
+    for i in range(num_silk_frames):
+        frame = lags[i * 4: (i+1) * 4]
+
+        if np.max(np.abs(frame)) == 0:
+            last = frame[4-d:]
+            continue
+
+        if i == num_silk_frames - 1:
+            next = frame[4-d:][::-1]
+        else:
+            next = lags[(i+1) * 4 : (i+1) * 4 + d]
+
+        if np.max(np.abs(next)) == 0:
+            next = frame[4-d:][::-1]
+
+        if np.max(np.abs(last)) == 0:
+            last = frame[0:d][::-1]
+
+        smoothed_frame = np.convolve(np.concatenate((last, frame, next), dtype=np.float32), kernel, mode='valid')
+
+        smoothed_lags[i * 4: (i+1) * 4] = np.round(smoothed_frame)
+
+        last = frame[4-d:]
+
+    return smoothed_lags
+
+def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
+    eps = 1e-9
+
+    lag_multiplier = 2 if add_double_lag_acorr else 1
+
+    if history is None:
+        history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)
+
+    offset = len(history)
+
+    assert offset >= max_lag + radius
+    assert len(x) % frame_size == 0
+
+    num_frames = len(x) // frame_size
+    lags = lags.copy()
+
+    x_ext = np.concatenate((history, x), dtype=x.dtype)
+
+    d = radius
+    num_acorrs = 2 * d + 1
+    acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)
+
+    for idx in range(num_frames):
+        lag = lags[idx].item()
+        frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]
+
+        for k in range(lag_multiplier):
+            lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
+            for j in range(num_acorrs):
+                past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
+                acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)
+
+    return acorrs, lags
\ No newline at end of file
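A small sketch of the two helpers in action (synthetic data; real lags come from the SILK bitstream dumps, and zero lags are mapped to a no-pitch value before the correlation, mirroring create_features below):

import numpy as np
from utils.pitch import hangover, calculate_acorr_window

# 20 frames of 80 samples; the pitch track drops to 0 (unvoiced) after frame 10
lags = np.array([100] * 10 + [0] * 10, dtype=np.int64)
lags = hangover(lags, num_frames=8)  # holds the last valid lag for up to 8 frames
lags[lags == 0] = 256                # map remaining unvoiced frames to a no-pitch value

x = np.random.randn(20 * 80).astype(np.float32)
acorrs, lags = calculate_acorr_window(x, 80, lags, max_lag=300, radius=2)
print(acorrs.shape)  # (20, 5): normalized correlation at lag offsets -2..+2 per frame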
diff --git a/dnn/torch/osce/utils/silk_features.py b/dnn/torch/osce/utils/silk_features.py
new file mode 100644
index 00000000..071a6c26
--- /dev/null
+++ b/dnn/torch/osce/utils/silk_features.py
@@ -0,0 +1,151 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import os
+
+import numpy as np
+import torch
+
+import scipy.signal
+
+from utils.pitch import hangover, calculate_acorr_window
+from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
+
+def spec_from_lpc(a, n_fft=128, eps=1e-9):
+    order = a.shape[-1]
+    assert order + 1 < n_fft
+
+    x = np.zeros((*a.shape[:-1], n_fft))
+    x[..., 0] = 1
+    x[..., 1:1 + order] = -a
+
+    X = np.fft.fft(x, axis=-1)
+    X = np.abs(X[..., :n_fft//2 + 1]) ** 2
+
+    S = 1 / (X + eps)
+
+    return S
+
+def silk_feature_factory(no_pitch_value=256,
+                         acorr_radius=2,
+                         pitch_hangover=8,
+                         num_bands_clean_spec=64,
+                         num_bands_noisy_spec=18,
+                         noisy_spec_scale='opus',
+                         noisy_apply_dct=True,
+                         add_offset=False,
+                         add_double_lag_acorr=False
+                         ):
+    """ returns a closure that assembles the frame-wise feature matrix from SILK bitstream data """
+
+    w = scipy.signal.windows.cosine(320)
+    fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
+    fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)
+
+    def create_features(noisy, noisy_history, lpcs, gains, ltps, periods, offsets):
+
+        periods = periods.copy()
+
+        if pitch_hangover > 0:
+            periods = hangover(periods, num_frames=pitch_hangover)
+
+        periods[periods == 0] = no_pitch_value
+
+        clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)
+
+        if noisy_apply_dct:
+            noisy_cepstrum = np.repeat(
+                cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
+        else:
+            noisy_cepstrum = np.repeat(
+                log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
+
+        log_gains = np.log(gains + 1e-9).reshape(-1, 1)
+
+        acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)
+
+        if add_offset:
+            features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains, offsets.reshape(-1, 1)), axis=-1, dtype=np.float32)
+        else:
+            features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)
+
+        return features, periods.astype(np.int64)
+
+    return create_features
+
+
+def load_inference_data(path,
+                        no_pitch_value=256,
+                        skip=92,
+                        preemph=0.85,
+                        acorr_radius=2,
+                        pitch_hangover=8,
+                        num_bands_clean_spec=64,
+                        num_bands_noisy_spec=18,
+                        noisy_spec_scale='opus',
+                        noisy_apply_dct=True,
+                        add_offset=False,
+                        add_double_lag_acorr=False,
+                        **kwargs):
+
+    print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")
+
+    lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
+    ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
+    gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
+    periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
+    num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
+    num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)
+    offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
+
+    # load signal, add back delay and pre-emphasize
+    signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
+    signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)
+
+    create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset, add_double_lag_acorr)
+
+    num_frames = min((len(signal) // 320) * 4, len(lpcs))
+    signal = signal[: num_frames * 80]
+    lpcs = lpcs[: num_frames]
+    ltps = ltps[: num_frames]
+    gains = gains[: num_frames]
+    periods = periods[: num_frames]
+    num_bits = num_bits[: num_frames // 4]
+    num_bits_smooth = num_bits_smooth[: num_frames // 4]
+    offsets = offsets[: num_frames]
+
+    numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)
+
+    features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods, offsets)
+
+    if preemph > 0:
+        signal[1:] -= preemph * signal[:-1]
+
+    return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)
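A hedged round-trip sketch (the directory contents are whatever the feature-generation branch of the Opus repo dumps; 'my_clip' is a made-up path and model a trained enhancement model):

from utils.silk_features import load_inference_data

signal, features, periods, numbits = load_inference_data('my_clip', skip=91, preemph=0.85)
# enhanced = model.process(signal, features, periods, numbits)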
diff --git a/dnn/torch/osce/utils/spec.py b/dnn/torch/osce/utils/spec.py
new file mode 100644
index 00000000..7e41d84e
--- /dev/null
+++ b/dnn/torch/osce/utils/spec.py
@@ -0,0 +1,194 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import math as m
+import numpy as np
+import scipy.fftpack
+
+def erb(f):
+    return 24.7 * (4.37 * f + 1)
+
+def inv_erb(e):
+    return (e / 24.7 - 1) / 4.37
+
+def bark(f):
+    return 6 * m.asinh(f/600)
+
+def inv_bark(b):
+    return 600 * m.sinh(b / 6)
+
+
+scale_dict = {
+    'bark': [bark, inv_bark],
+    'erb': [erb, inv_erb]
+}
+
+def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
+
+    f0 = 0
+    num_bins = n_fft // 2 + 1
+    f1 = fs / n_fft * (num_bins - 1)
+    fstep = fs / n_fft
+
+    if scale == 'opus':
+        bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
+        fac = 1000 * n_fft / fs / 5
+        if num_bands != 18:
+            print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
+            num_bands = 18
+        center_bins = np.array([fac * bin for bin in bins_5ms])
+    else:
+        to_scale, from_scale = scale_dict[scale]
+
+        s0 = to_scale(f0)
+        s1 = to_scale(f1)
+
+        center_freqs = np.array([f0] + [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)] + [f1])
+        center_bins = (center_freqs - f0) / fstep
+
+    if round_center_bins:
+        center_bins = np.round(center_bins)
+
+    filter_bank = np.zeros((num_bands, num_bins))
+
+    band = 0
+    for bin in range(num_bins):
+        # update band index
+        if bin > center_bins[band + 1]:
+            band += 1
+
+        # calculate filter coefficients
+        frac = (center_bins[band + 1] - bin) / (center_bins[band + 1] - center_bins[band])
+        filter_bank[band][bin] = frac
+        filter_bank[band + 1][bin] = 1 - frac
+
+    if return_upper:
+        extend = n_fft - num_bins
+        filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)
+
+    if normalize:
+        filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)
+
+    return filter_bank
+
+
+def compressed_log_spec(pspec):
+
+    lpspec = np.zeros_like(pspec)
+    num_bands = pspec.shape[-1]
+
+    log_max = -2
+    follow = -2
+
+    for i in range(num_bands):
+        tmp = np.log10(pspec[i] + 1e-9)
+        tmp = max(log_max, max(follow - 2.5, tmp))
+        lpspec[i] = tmp
+        log_max = max(log_max, tmp)
+        follow = max(follow - 2.5, tmp)
+
+    return lpspec
+
+def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
+    """ calculates log spectrum from SILK lpcs """
+    order = a.shape[-1]
+    assert order + 1 < n_fft
+
+    a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)
+
+    x = np.zeros((*a.shape[:-1], n_fft))
+    x[..., 0] = 1
+    x[..., 1:1 + order] = -a
+
+    X = np.fft.fft(x, axis=-1)
+    X = np.abs(X[..., :n_fft//2 + 1]) ** power
+
+    S = 1 / (X + eps)
+
+    if fb is None:
+        Sf = S
+    else:
+        Sf = np.matmul(S, fb.T)
+
+    if compress:
+        Sf = np.apply_along_axis(compressed_log_spec, -1, Sf)
+    else:
+        Sf = np.log(Sf + eps)
+
+    return Sf
+
+def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
+    """ calculates cepstrum from SILK lpcs """
+
+    Sf = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
+
+    cepstrum = scipy.fftpack.dct(Sf, 2, norm='ortho')
+
+    return cepstrum
+
+
+def log_spectrum(x, frame_size, fb=None, window=None, power=1):
+    """ calculates log spectrum on 50% overlapping frames """
+
+    assert (2 * len(x)) % frame_size == 0
+    assert frame_size % 2 == 0
+
+    n = len(x)
+    num_even = n // frame_size
+    num_odd = (n - frame_size // 2) // frame_size
+    num_bins = frame_size // 2 + 1
+
+    x_even = x[:num_even * frame_size].reshape(-1, frame_size)
+    x_odd = x[frame_size//2 : frame_size//2 + frame_size * num_odd].reshape(-1, frame_size)
+
+    x_unfold = np.empty((x_even.size + x_odd.size), dtype=x.dtype).reshape((-1, frame_size))
+    x_unfold[::2, :] = x_even
+    x_unfold[1::2, :] = x_odd
+
+    if window is not None:
+        x_unfold *= window.reshape(1, -1)
+
+    X = np.abs(np.fft.fft(x_unfold, n=frame_size, axis=-1))[:, :num_bins] ** power
+
+    if fb is not None:
+        X = np.matmul(X, fb.T)
+
+    return np.log(X + 1e-9)
+
+
+def cepstrum(x, frame_size, fb=None, window=None):
+    """ calculates cepstrum on 50% overlapping frames """
+
+    X = log_spectrum(x, frame_size, fb, window)
+
+    cepstrum = scipy.fftpack.dct(X, 2, norm='ortho')
+
+    return cepstrum
\ No newline at end of file
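A quick sanity-check sketch for the filter-bank and cepstrum utilities (one second of noise at 16 kHz, 20 ms frames with a 10 ms hop; all values are illustrative):

import numpy as np
import scipy.signal
from utils.spec import create_filter_bank, cepstrum

fb = create_filter_bank(18, n_fft=320, fs=16000, scale='opus', round_center_bins=True, normalize=True)
w = scipy.signal.windows.cosine(320)

x = np.random.randn(16000).astype(np.float32)
c = cepstrum(x, 320, fb, w)
print(c.shape)  # (99, 18): one 18-band cepstrum per 10 ms hop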
diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py
new file mode 100644
index 00000000..1232710f
--- /dev/null
+++ b/dnn/torch/osce/utils/templates.py
@@ -0,0 +1,92 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+setup_dict = dict()
+
+lace_setup = {
+    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
+    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
+    'model': {
+        'name': 'lace',
+        'args': [],
+        'kwargs': {
+            'comb_gain_limit_db': 10,
+            'cond_dim': 128,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'hidden_feature_dim': 96,
+            'kernel_size': 15,
+            'num_features': 93,
+            'numbits_embedding_dim': 8,
+            'numbits_range': [50, 650],
+            'partial_lookahead': True,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'skip': 91
+        }
+    },
+    'data': {
+        'frames_per_sample': 100,
+        'no_pitch_value': 7,
+        'preemph': 0.85,
+        'skip': 91,
+        'pitch_hangover': 8,
+        'acorr_radius': 2,
+        'num_bands_clean_spec': 64,
+        'num_bands_noisy_spec': 18,
+        'noisy_spec_scale': 'opus',
+    },
+    'training': {
+        'batch_size': 256,
+        'lr': 5.e-4,
+        'lr_decay_factor': 2.5e-5,
+        'epochs': 50,
+        'frames_per_sample': 50,
+        'loss': {
+            'w_l1': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+            'w_sxcorr': 1,
+            'w_l2': 10,
+            'w_slm': 2
+        }
+    }
+}
+
+
+setup_dict = {
+    'lace': lace_setup,
+}
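Finally, a hedged sketch of how a training run might consume one of these templates (the yaml dump and the output filename are assumptions; the actual training entry point is not part of this excerpt):

import yaml
from utils.templates import setup_dict

setup = setup_dict['lace']
setup['dataset'] = '/path/to/training/data'  # adapt dataset paths before use

with open('lace_setup.yml', 'w') as f:
    yaml.dump(setup, f)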