"""
|
|
/* Copyright (c) 2023 Amazon
|
|
Written by Jan Buethe */
|
|
/*
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions
|
|
are met:
|
|
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
|
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*/
|
|
"""
|
|
|
|
""" Dataset for LPCNet training """
|
|
import os
|
|
|
|
import yaml
|
|
import torch
|
|
import numpy as np
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
# scaling constants between 16-bit linear and 8-bit u-law amplitudes
scale = 255.0 / 32768.0
scale_1 = 32768.0 / 255.0

def ulaw2lin(u):
    """ converts 8-bit u-law encoded values to 16-bit linear amplitudes """
    u = u - 128
    s = np.sign(u)
    u = np.abs(u)
    return s * scale_1 * (np.exp(u / 128. * np.log(256)) - 1)

def lin2ulaw(x):
    """ converts 16-bit linear amplitudes to 8-bit u-law encoded values """
    s = np.sign(x)
    x = np.abs(x)
    u = s * 128 * np.log(1 + scale * x) / np.log(256)
    u = np.clip(128 + np.round(u), 0, 255)
    return u

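def _ulaw_roundtrip_check():
    # Illustrative sketch, not part of the original training code: u-law
    # companding is lossy, so lin2ulaw followed by ulaw2lin only recovers the
    # input up to quantization error (roughly a few percent for large
    # amplitudes).
    x = np.array([-20000.0, -100.0, 0.0, 100.0, 20000.0])
    x_hat = ulaw2lin(lin2ulaw(x))
    assert np.allclose(x, x_hat, rtol=0.05, atol=1)
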
def run_lpc(signal, lpcs, frame_length=160):
    """ runs frame-wise LPC analysis; signal carries lpc_order extra samples of history at its start """
    num_frames, lpc_order = lpcs.shape

    # predict each frame from the lpc_order preceding samples
    prediction = np.concatenate(
        [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
    )
    error = signal[lpc_order :] - prediction

    return prediction, error

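def _run_lpc_example():
    # Minimal usage sketch (illustrative only): analyze a synthetic AR(1)
    # signal with the matching predictor. In the coefficient convention used
    # by run_lpc, x[n] is predicted as -sum_j lpcs[j] * x[n-1-j], so a pole at
    # 0.9 is stored as -0.9 and the prediction error reduces to the driving
    # noise.
    rng = np.random.default_rng(0)
    noise = rng.standard_normal(3 * 160 + 1)
    sig = np.empty_like(noise)
    sig[0] = noise[0]
    for n in range(1, len(sig)):
        sig[n] = 0.9 * sig[n - 1] + noise[n]
    lpcs = np.tile(np.array([[-0.9]]), (3, 1))  # one coefficient per frame
    prediction, error = run_lpc(sig, lpcs, frame_length=160)
    assert np.allclose(error, noise[1:])
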
class LPCNetVocodingDataset(Dataset):
    """ Dataset for LPCNet vocoder training (supports dataset format versions 1 and 2) """
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1):

        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version determines the sample layout returned by __getitem__
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        self.frame_offset = 2 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.target = target
        # input signals for the legacy v1 format (used by getitem_v1);
        # assumed default: every signal in the layout except the target
        self.input_signals = [name for name in self.signal_frame_layout if name != target]

        # load signals
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length

    def __getitem__(self, index):
        return self.getitem(index)

    def getitem_v2(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients are present and prediction is not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting (bandwidth expansion by lpc_gamma)
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal positions (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (drop the extra lookback frame)
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        target = torch.FloatTensor(sample[self.target]) / 2**15
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'target' : target}

    def getitem_v1(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        for signal_name, signal_column in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, signal_column]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        return self.dataset_length
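
if __name__ == "__main__":
    # Hedged usage sketch: 'path/to/dataset' is a placeholder for a directory
    # containing the info.yml, feature, and signal files described above.
    from torch.utils.data import DataLoader

    dataset = LPCNetVocodingDataset('path/to/dataset', frames_per_sample=100)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    print({key: value.shape for key, value in batch.items()})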