opus/dnn/torch/osce/data/lpcnet_vocoding_dataset.py
Jan Buethe 2f290d32ed
added more enhancement stuff
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
2023-09-12 14:50:24 +02:00

225 lines
9.6 KiB
Python

"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" Dataset for LPCNet training """
import os
import yaml
import torch
import numpy as np
from torch.utils.data import Dataset
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
def ulaw2lin(u):
u = u - 128
s = np.sign(u)
u = np.abs(u)
return s*scale_1*(np.exp(u/128.*np.log(256))-1)
def lin2ulaw(x):
s = np.sign(x)
x = np.abs(x)
u = (s*(128*np.log(1+scale*x)/np.log(256)))
u = np.clip(128 + np.round(u), 0, 255)
return u
def run_lpc(signal, lpcs, frame_length=160):
num_frames, lpc_order = lpcs.shape
prediction = np.concatenate(
[- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
)
error = signal[lpc_order :] - prediction
return prediction, error
class LPCNetVocodingDataset(Dataset):
def __init__(self,
path_to_dataset,
features=['cepstrum', 'periods', 'pitch_corr'],
target='signal',
frames_per_sample=100,
feature_history=0,
feature_lookahead=0,
lpc_gamma=1):
super().__init__()
# load dataset info
self.path_to_dataset = path_to_dataset
with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
dataset = yaml.load(f, yaml.FullLoader)
# dataset version
self.version = dataset['version']
if self.version == 1:
self.getitem = self.getitem_v1
elif self.version == 2:
self.getitem = self.getitem_v2
else:
raise ValueError(f"dataset version {self.version} unknown")
# features
self.feature_history = feature_history
self.feature_lookahead = feature_lookahead
self.frame_offset = 2 + self.feature_history
self.frames_per_sample = frames_per_sample
self.input_features = features
self.feature_frame_layout = dataset['feature_frame_layout']
self.lpc_gamma = lpc_gamma
# load feature file
self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
self.feature_frame_length = dataset['feature_frame_length']
assert len(self.features) % self.feature_frame_length == 0
self.features = self.features.reshape((-1, self.feature_frame_length))
# derive number of samples is dataset
self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample
# signals
self.frame_length = dataset['frame_length']
self.signal_frame_layout = dataset['signal_frame_layout']
self.target = target
# load signals
self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
self.signal_frame_length = dataset['signal_frame_length']
self.signals = self.signals.reshape((-1, self.signal_frame_length))
assert len(self.signals) == len(self.features) * self.frame_length
def __getitem__(self, index):
return self.getitem(index)
def getitem_v2(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]
# calculate prediction and error if lpc coefficients present and prediction not given
if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
# lpc coefficients with one frame lookahead
# frame positions (start one frame early for past excitation)
frame_start = self.frame_offset + self.frames_per_sample * index - 1
frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)
# feature positions
lpc_start, lpc_stop = self.feature_frame_layout['lpc']
lpc_order = lpc_stop - lpc_start
lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]
# LPC weighting
lpc_order = lpc_stop - lpc_start
weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
lpcs = lpcs * weights
# signal position (lpc_order samples as history)
signal_start = frame_start * self.frame_length - lpc_order + 1
signal_stop = frame_stop * self.frame_length + 1
noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]
noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)
# extract signals
offset = self.frame_length
sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
# calculate error between real signal and noisy prediction
sample['error'] = sample['signal'] - sample['prediction']
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
target = torch.FloatTensor(sample[self.target]) / 2**15
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'target' : target}
def getitem_v1(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
for signal_name, index in self.signal_frame_layout.items():
sample[signal_name] = self.signals[signal_start : signal_stop, index]
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
target = torch.LongTensor(sample[self.target])
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
def __len__(self):
return self.dataset_length