mirror of
https://github.com/xiph/opus.git
synced 2025-05-23 11:49:12 +00:00
added more enhancement stuff
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
This commit is contained in:
parent
7b8ba143f1
commit
2f290d32ed
24 changed files with 3511 additions and 108 deletions
225
dnn/torch/osce/data/lpcnet_vocoding_dataset.py
Normal file
225
dnn/torch/osce/data/lpcnet_vocoding_dataset.py
Normal file
|
@ -0,0 +1,225 @@
|
|||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" Dataset for LPCNet training """
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# u-law <-> linear companding helpers (8-bit codes, bias 128, mu = 255 style).
scale = 255.0/32768.0    # linear full-scale -> companding domain
scale_1 = 32768.0/255.0  # companding domain -> linear full-scale


def ulaw2lin(u):
    """Expand 8-bit u-law codes (0..255, bias 128) back to linear amplitudes."""
    centered = u - 128
    sign = np.sign(centered)
    magnitude = np.abs(centered)
    # invert the logarithmic compression: exp(m/128 * ln 256) == 256**(m/128)
    return sign * scale_1 * (np.exp(magnitude / 128. * np.log(256)) - 1)


def lin2ulaw(x):
    """Compress linear amplitudes to 8-bit u-law codes (0..255, bias 128)."""
    sign = np.sign(x)
    magnitude = np.abs(x)
    compressed = sign * (128 * np.log(1 + scale * magnitude) / np.log(256))
    # re-bias to the unsigned 0..255 code range and quantize
    return np.clip(128 + np.round(compressed), 0, 255)
|
||||
|
||||
|
||||
def run_lpc(signal, lpcs, frame_length=160):
    """Apply frame-wise linear prediction to a signal.

    Args:
        signal: 1-d sample array holding lpc_order - 1 extra leading history
            samples in front of the num_frames * frame_length frame samples.
        lpcs: array of shape (num_frames, lpc_order) with per-frame LPC
            coefficients.
        frame_length: number of samples per frame.

    Returns:
        Tuple (prediction, error): the negated 'valid' convolution of each
        frame window with its coefficients, concatenated over frames, and
        error = signal[lpc_order:] - prediction.
    """
    num_frames, lpc_order = lpcs.shape

    frame_predictions = []
    for frame_idx, coeffs in enumerate(lpcs):
        # each window carries lpc_order - 1 leading samples so the 'valid'
        # convolution produces exactly frame_length outputs
        start = frame_idx * frame_length
        stop = (frame_idx + 1) * frame_length + lpc_order - 1
        frame_predictions.append(-np.convolve(signal[start:stop], coeffs, mode='valid'))

    prediction = np.concatenate(frame_predictions)
    error = signal[lpc_order:] - prediction

    return prediction, error
|
||||
|
||||
class LPCNetVocodingDataset(Dataset):
    """Dataset for LPCNet vocoder training.

    Reads an ``info.yml`` in the dataset folder describing a feature file and
    a signal file (both memory-mapped), and serves fixed-size windows of
    feature frames together with the corresponding sample-rate signals.
    """

    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1):
        """
        Args:
            path_to_dataset: folder containing info.yml plus the feature and
                signal files it references.
            features: names of feature slices (keys of feature_frame_layout in
                info.yml) returned as model input. Default list is never
                mutated, so the shared-mutable-default pitfall does not apply.
            target: name of the signal channel used as training target.
            frames_per_sample: number of feature frames per dataset item.
            feature_history: extra past frames prepended to the feature window.
            feature_lookahead: extra future frames appended to the window.
            lpc_gamma: bandwidth-expansion factor applied to LPCs (v2 only).

        Raises:
            ValueError: if info.yml declares an unknown dataset version.
        """

        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version selects the item-assembly routine
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        # first usable frame: 2 warm-up frames plus the requested history
        self.frame_offset = 2 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file (memory-mapped, reshaped to (num_frames, frame_dim))
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in the dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.target = target

        # load signals (memory-mapped, one row of channels per sample)
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length


    def __getitem__(self, index):
        # dispatch to the version-specific routine chosen in __init__
        return self.getitem(index)

    def getitem_v2(self, index):
        """Assemble a version-2 sample: input features, periods and target,
        deriving LPC prediction/error on the fly when only LPCs are stored."""
        sample = dict()

        # extract features (window extended by history and lookahead frames)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods from the stored float encoding to integer values
        # (the 0.1 offset guards the truncation against round-off)
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions (lpc_order computed once; the original
            # recomputed it redundantly)
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting: scale coefficient i by gamma ** (i + 1)
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals, dropping the extra leading frame used as history
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features (periods are returned separately as integers)
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        target = torch.FloatTensor(sample[self.target]) / 2**15
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'target' : target}

    def getitem_v1(self, index):
        """Assemble a version-1 sample: input features, periods, stored signal
        channels and the integer-valued target."""
        sample = dict()

        # extract features (window extended by history and lookahead frames)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods from the stored float encoding to integer values
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        # (loop variable renamed: the original reused `index`, shadowing the
        # sample-index argument)
        for signal_name, column in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, column]

        # concatenate features (periods are returned separately as integers)
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        # NOTE(review): self.input_signals is never assigned in __init__, so
        # this path raises AttributeError at runtime — confirm the intended
        # attribute (possibly the keys of signal_frame_layout)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        # number of fixed-size samples derived in __init__
        return self.dataset_length
|
Loading…
Add table
Add a link
Reference in a new issue