added LPCNet torch implementation

Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Jan Buethe 2023-09-05 12:29:38 +02:00
parent 90a171c1c2
commit 35ee397e06
No known key found for this signature in database
GPG key ID: 9E32027A35B36314
38 changed files with 3200 additions and 0 deletions

View file

@@ -0,0 +1,27 @@
# LPCNet
Incomplete PyTorch implementation of LPCNet.
## Data preparation
For data preparation, use the `dump_data` tool from github.com/xiph/LPCNet. To turn the result into
a training dataset, copy the data and feature files to a folder and run
```
python add_dataset_config.py my_dataset_folder
```
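After running the script, the folder should contain exactly one `.f32` feature file, one data file (`.u8` for version 1, `.s16` for the default version 2), and the generated `info.yml` (file names below are examples):
```
my_dataset_folder/
    features.f32
    data.s16
    info.yml
```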
## Training
To train a model, create and adjust a setup file, e.g. with
```
python make_default_setup.py my_setup.yml --path2dataset my_dataset_folder
```
Then simply run
```
python train_lpcnet.py my_setup.yml my_output
```
## Inference
Create a feature file with `dump_data` from github.com/xiph/LPCNet, then run e.g.
```
python test_lpcnet.py features.f32 my_output/checkpoints/checkpoint_ep_10.pth out.wav
```
Inference runs on CPU and usually takes between 3 and 20 seconds per generated second of audio,
depending on the CPU.

View file

@@ -0,0 +1,48 @@
import argparse
import os
import yaml
from utils.templates import dataset_template_v1, dataset_template_v2
parser = argparse.ArgumentParser("add_dataset_config.py")
parser.add_argument('path', type=str, help='path to folder containing feature and data file')
parser.add_argument('--version', type=int, help="dataset version, 1 for classic LPCNet with 55 feature slots, 2 for new format with 36 feature slots.", default=2)
parser.add_argument('--description', type=str, help='brief dataset description', default="I will add a description later")
args = parser.parse_args()
if args.version == 1:
template = dataset_template_v1
data_extension = '.u8'
elif args.version == 2:
template = dataset_template_v2
data_extension = '.s16'
else:
raise ValueError(f"unknown dataset version {args.version}")
# get folder content
content = os.listdir(args.path)
features = [c for c in content if c.endswith('.f32')]
if len(features) != 1:
print("could not determine feature file")
else:
template['feature_file'] = features[0]
data = [c for c in content if c.endswith(data_extension)]
if len(data) != 1:
print("could not determine data file")
else:
template['signal_file'] = data[0]
template['description'] = args.description
with open(os.path.join(args.path, 'info.yml'), 'w') as f:
yaml.dump(template, f)
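For reference, the generated info.yml looks roughly like the sketch below; the exact field set comes from dataset_template_v1/v2 in utils.templates, and all values here are purely illustrative:

version: 2
description: I will add a description later
feature_file: features.f32
signal_file: data.s16
feature_dtype: float32
feature_frame_length: 36
frame_length: 160
signal_dtype: int16
signal_frame_length: 2
feature_frame_layout:
  cepstrum: [0, 18]
  periods: [18, 19]
  pitch_corr: [19, 20]
  lpc: [20, 36]
signal_frame_layout:
  last_signal: 0
  signal: 1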

View file

@@ -0,0 +1 @@
from .lpcnet_dataset import LPCNetDataset

View file

@@ -0,0 +1,198 @@
""" Dataset for LPCNet training """
import os
import yaml
import torch
import numpy as np
from torch.utils.data import Dataset
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
def ulaw2lin(u):
u = u - 128
s = np.sign(u)
u = np.abs(u)
return s*scale_1*(np.exp(u/128.*np.log(256))-1)
def lin2ulaw(x):
s = np.sign(x)
x = np.abs(x)
u = (s*(128*np.log(1+scale*x)/np.log(256)))
u = np.clip(128 + np.round(u), 0, 255)
return u
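# note: lin2ulaw/ulaw2lin implement standard 8-bit mu-law companding
# (mu = 255) over the int16 range: linear samples are compressed to 256
# levels centered around 128, and ulaw2lin inverts the mapping up to
# quantization error (the u-law sample format used on the LPCNet side)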
def run_lpc(signal, lpcs, frame_length=160):
num_frames, lpc_order = lpcs.shape
prediction = np.concatenate(
[- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
)
error = signal[lpc_order :] - prediction
return prediction, error
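# run_lpc computes, frame by frame, the linear prediction
#   p[n] = -sum_{k=1..lpc_order} a_k * x[n - k]
# with each frame's own LPC coefficients (np.convolve reverses the
# kernel, which is why each frame is extended by lpc_order - 1 history
# samples); the residual e[n] = x[n] - p[n] is the excitation signal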
class LPCNetDataset(Dataset):
def __init__(self,
path_to_dataset,
features=['cepstrum', 'periods', 'pitch_corr'],
input_signals=['last_signal', 'prediction', 'last_error'],
target='error',
frames_per_sample=15,
feature_history=2,
feature_lookahead=2,
lpc_gamma=1):
super(LPCNetDataset, self).__init__()
# load dataset info
self.path_to_dataset = path_to_dataset
with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
dataset = yaml.load(f, yaml.FullLoader)
# dataset version
self.version = dataset['version']
if self.version == 1:
self.getitem = self.getitem_v1
elif self.version == 2:
self.getitem = self.getitem_v2
else:
raise ValueError(f"dataset version {self.version} unknown")
# features
self.feature_history = feature_history
self.feature_lookahead = feature_lookahead
self.frame_offset = 1 + self.feature_history
self.frames_per_sample = frames_per_sample
self.input_features = features
self.feature_frame_layout = dataset['feature_frame_layout']
self.lpc_gamma = lpc_gamma
# load feature file
self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
self.feature_frame_length = dataset['feature_frame_length']
assert len(self.features) % self.feature_frame_length == 0
self.features = self.features.reshape((-1, self.feature_frame_length))
# derive number of samples in dataset
self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample
# signals
self.frame_length = dataset['frame_length']
self.signal_frame_layout = dataset['signal_frame_layout']
self.input_signals = input_signals
self.target = target
# load signals
self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
self.signal_frame_length = dataset['signal_frame_length']
self.signals = self.signals.reshape((-1, self.signal_frame_length))
assert len(self.signals) == len(self.features) * self.frame_length
def __getitem__(self, index):
return self.getitem(index)
def getitem_v2(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]
# calculate prediction and error if lpc coefficients present and prediction not given
if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
# lpc coefficients with one frame lookahead
# frame positions (start one frame early for past excitation)
frame_start = self.frame_offset + self.frames_per_sample * index - 1
frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)
# feature positions
lpc_start, lpc_stop = self.feature_frame_layout['lpc']
lpc_order = lpc_stop - lpc_start
lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]
# LPC weighting
weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
lpcs = lpcs * weights
# signal position (lpc_order samples as history)
signal_start = frame_start * self.frame_length - lpc_order + 1
signal_stop = frame_stop * self.frame_length + 1
noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]
noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)
# extract signals
offset = self.frame_length
sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
# calculate error between real signal and noisy prediction
sample['error'] = sample['signal'] - sample['prediction']
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
target = torch.LongTensor(lin2ulaw(sample[self.target]))
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
def getitem_v1(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
for signal_name, signal_index in self.signal_frame_layout.items():
sample[signal_name] = self.signals[signal_start : signal_stop, signal_index]
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
target = torch.LongTensor(sample[self.target])
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
def __len__(self):
return self.dataset_length
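A minimal usage sketch (dataset folder and batch size are placeholders; assumes info.yml was created with add_dataset_config.py):

from torch.utils.data import DataLoader
from data import LPCNetDataset

dataset = LPCNetDataset('my_dataset_folder')
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
batch = next(iter(loader))
# batch['features']: float conditioning features (incl. history/lookahead frames)
# batch['periods']:  long pitch indices for the period embedding
# batch['signals']:  long u-law input signals (e.g. last_signal, prediction, last_error)
# batch['target']:   long u-law targets, e.g. the excitation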

View file

@@ -0,0 +1,112 @@
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
model.to(device)
model.train()
running_loss = 0
previous_running_loss = 0
# gru states
gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
gru_states = [gru_a_state, gru_b_state]
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# set gradients to zero
optimizer.zero_grad()
# zero out initial gru states
gru_a_state.zero_()
gru_b_state.zero_()
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target']
# calculate model output
output = model(batch['features'], batch['periods'], batch['signals'], gru_states)
# calculate loss
loss = criterion(output.permute(0, 2, 1), target)
# calculate gradients
loss.backward()
# update weights
optimizer.step()
# update learning rate
scheduler.step()
# call sparsifier
model.sparsify()
# update running loss
running_loss += float(loss.cpu())
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
previous_running_loss = running_loss
running_loss /= len(dataloader)
return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
model.to(device)
model.eval()
running_loss = 0
previous_running_loss = 0
# gru states
gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
gru_states = [gru_a_state, gru_b_state]
with torch.no_grad():
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
for i, batch in enumerate(tepoch):
# zero out initial gru states
gru_a_state.zero_()
gru_b_state.zero_()
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
target = batch['target']
# calculate model output
output = model(batch['features'], batch['periods'], batch['signals'], gru_states)
# calculate loss
loss = criterion(output.permute(0, 2, 1), target)
# update running loss
running_loss += float(loss.cpu())
# update status bar
if i % log_interval == 0:
tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
previous_running_loss = running_loss
running_loss /= len(dataloader)
return running_loss

View file

@@ -0,0 +1,27 @@
import argparse
import yaml
from utils.templates import setup_dict
parser = argparse.ArgumentParser()
parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lpcnet', 'multi_rate'], help='LPCNet model name', default='lpcnet')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
args = parser.parse_args()
setup = setup_dict[args.model]
# update dataset if given
if args.path2dataset is not None:
setup['dataset'] = args.path2dataset
name = args.name
if not name.endswith('.yml'):
name += '.yml'
if __name__ == '__main__':
with open(name, 'w') as f:
f.write(yaml.dump(setup))

View file

@@ -0,0 +1,49 @@
import argparse
import os
import sys
parser = argparse.ArgumentParser()
parser.add_argument("config_name", type=str, help="name of config file (.yml will be appended)")
parser.add_argument("test_name", type=str, help="name for test result display")
parser.add_argument("checkpoint", type=str, help="checkpoint to test")
parser.add_argument("--lpcnet-demo", type=str, help="path to lpcnet_demo binary, default: /local/code/LPCNet/lpcnet_demo", default="/local/code/LPCNet/lpcnet_demo")
parser.add_argument("--lpcnext-path", type=str, help="path to lpcnext folder, defalut: dirname(__file__)", default=os.path.dirname(__file__))
parser.add_argument("--python-exe", type=str, help='python executable path, default: sys.executable', default=sys.executable)
parser.add_argument("--pad", type=str, help="left pad of output in seconds, default: 0.015", default="0.015")
parser.add_argument("--trim", type=str, help="left trim of output in seconds, default: 0", default="0")
template='''
test: "{NAME}"
processing:
- "sox {{INPUT}} {{INPUT}}.raw"
- "{LPCNET_DEMO} -features {{INPUT}}.raw {{INPUT}}.features.f32"
- "{PYTHON} {WORKING}/test_lpcnet.py {{INPUT}}.features.f32 {CHECKPOINT} {{OUTPUT}}.ua.wav"
- "sox {{OUTPUT}}.ua.wav {{OUTPUT}}.uap.wav pad {PAD}"
- "sox {{OUTPUT}}.uap.wav {{OUTPUT}} trim {TRIM}"
- "rm {{INPUT}}.raw {{OUTPUT}}.uap.wav {{OUTPUT}}.ua.wav {{INPUT}}.features.f32"
'''
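# note: the double braces ({{INPUT}}, {{OUTPUT}}) are escaped for the
# str.format call below; they end up as single-brace placeholders
# ({INPUT}, {OUTPUT}) in the written config, to be filled in later by
# the test suite's run_test.py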
if __name__ == "__main__":
args = parser.parse_args()
file_content = template.format(
NAME=args.test_name,
LPCNET_DEMO=os.path.abspath(args.lpcnet_demo),
PYTHON=os.path.abspath(args.python_exe),
PAD=args.pad,
TRIM=args.trim,
WORKING=os.path.abspath(args.lpcnext_path),
CHECKPOINT=os.path.abspath(args.checkpoint)
)
print(file_content)
filename = args.config_name
if not filename.endswith(".yml"):
filename += ".yml"
with open(filename, "w") as f:
f.write(file_content)

View file

@@ -0,0 +1,8 @@
from .lpcnet import LPCNet
from .multi_rate_lpcnet import MultiRateLPCNet
model_dict = {
'lpcnet' : LPCNet,
'multi_rate' : MultiRateLPCNet
}

View file

@@ -0,0 +1,274 @@
import torch
from torch import nn
import numpy as np
from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.layers import DualFC
from utils.misc import get_pdf_from_tree
class LPCNet(nn.Module):
def __init__(self, config):
super(LPCNet, self).__init__()
#
self.input_layout = config['input_layout']
self.feature_history = config['feature_history']
self.feature_lookahead = config['feature_lookahead']
# frame rate network parameters
self.feature_dimension = config['feature_dimension']
self.period_embedding_dim = config['period_embedding_dim']
self.period_levels = config['period_levels']
self.feature_channels = self.feature_dimension + self.period_embedding_dim
self.feature_conditioning_dim = config['feature_conditioning_dim']
self.feature_conv_kernel_size = config['feature_conv_kernel_size']
# frame rate network layers
self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
self.feature_dense2 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
# sample rate network parameters
self.frame_size = config['frame_size']
self.signal_levels = config['signal_levels']
self.signal_embedding_dim = config['signal_embedding_dim']
self.gru_a_units = config['gru_a_units']
self.gru_b_units = config['gru_b_units']
self.output_levels = config['output_levels']
self.hsampling = config.get('hsampling', False)
self.gru_a_input_dim = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim
# sample rate network layers
self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
self.dual_fc = DualFC(self.gru_b_units, self.output_levels)
# sparsification
self.sparsifier = []
# GRU A
if 'gru_a' in config['sparsification']:
gru_config = config['sparsification']['gru_a']
task_list = [(self.gru_a, gru_config['params'])]
self.sparsifier.append(GRUSparsifier(task_list,
gru_config['start'],
gru_config['stop'],
gru_config['interval'],
gru_config['exponent'])
)
self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
gru_config['params'], drop_input=True)
else:
self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)
# GRU B
if 'gru_b' in config['sparsification']:
gru_config = config['sparsification']['gru_b']
task_list = [(self.gru_b, gru_config['params'])]
self.sparsifier.append(GRUSparsifier(task_list,
gru_config['start'],
gru_config['stop'],
gru_config['interval'],
gru_config['exponent'])
)
self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
gru_config['params'])
else:
self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)
# inference parameters
self.lpc_gamma = config.get('lpc_gamma', 1)
def sparsify(self):
for sparsifier in self.sparsifier:
sparsifier.step()
def get_gflops(self, fs, verbose=False):
gflops = 0
# frame rate network
conditioning_dim = self.feature_conditioning_dim
feature_channels = self.feature_channels
frame_rate = fs / self.frame_size
frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
if verbose:
print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
gflops += frame_rate_network_complexity
# gru a
gru_a_rate = fs
gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
if verbose:
print(f"gru A: {gru_a_complexity} GFLOPS")
gflops += gru_a_complexity
# gru b
gru_b_rate = fs
gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
if verbose:
print(f"gru B: {gru_b_complexity} GFLOPS")
gflops += gru_b_complexity
# dual fcs
fc = self.dual_fc
rate = fs
input_size = fc.dense1.in_features
output_size = fc.dense1.out_features
dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
if self.hsampling:
dual_fc_complexity /= 8
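# hierarchical sampling walks a binary probability tree instead of
# evaluating a flat softmax, so only log2(output_levels) = 8 binary
# decisions are needed per sample (assuming 256 output levels)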
if verbose:
print(f"dual_fc: {dual_fc_complexity} GFLOPS")
gflops += dual_fc_complexity
if verbose:
print(f'total: {gflops} GFLOPS')
return gflops
def frame_rate_network(self, features, periods):
embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
features = torch.concat((features, embedded_periods), dim=-1)
# convert to channels first and calculate conditioning vector
c = torch.permute(features, [0, 2, 1])
c = torch.tanh(self.feature_conv1(c))
c = torch.tanh(self.feature_conv2(c))
# back to channels last
c = torch.permute(c, [0, 2, 1])
c = torch.tanh(self.feature_dense1(c))
c = torch.tanh(self.feature_dense2(c))
return c
def sample_rate_network(self, signals, c, gru_states):
embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1)
y = torch.concat((embedded_signals, c_upsampled), dim=-1)
y, gru_a_state = self.gru_a(y, gru_states[0])
y = torch.concat((y, c_upsampled), dim=-1)
y, gru_b_state = self.gru_b(y, gru_states[1])
y = self.dual_fc(y)
if self.hsampling:
y = torch.sigmoid(y)
log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
else:
log_probs = torch.log_softmax(y, dim=-1)
return log_probs, (gru_a_state, gru_b_state)
def decoder(self, signals, c, gru_states):
embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
y = torch.concat((embedded_signals, c), dim=-1)
y, gru_a_state = self.gru_a(y, gru_states[0])
y = torch.concat((y, c), dim=-1)
y, gru_b_state = self.gru_b(y, gru_states[1])
y = self.dual_fc(y)
if self.hsampling:
y = torch.sigmoid(y)
probs = get_pdf_from_tree(y)
else:
probs = torch.softmax(y, dim=-1)
return probs, (gru_a_state, gru_b_state)
def forward(self, features, periods, signals, gru_states):
c = self.frame_rate_network(features, periods)
log_probs, _ = self.sample_rate_network(signals, c, gru_states)
return log_probs
def generate(self, features, periods, lpcs):
with torch.no_grad():
device = next(self.parameters()).device
num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
lpc_order = lpcs.shape[-1]
num_input_signals = len(self.input_layout['signals'])
pitch_corr_position = self.input_layout['features']['pitch_corr'][0]
# signal buffers
pcm = torch.zeros((num_frames * self.frame_size + lpc_order))
output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
mem = 0
# state buffers
gru_a_state = torch.zeros((1, 1, self.gru_a_units))
gru_b_state = torch.zeros((1, 1, self.gru_b_units))
gru_states = [gru_a_state, gru_b_state]
input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128
# push data to device
features = features.to(device)
periods = periods.to(device)
lpcs = lpcs.to(device)
# lpc weighting
weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
lpcs = lpcs * weights
# run feature encoding
c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))
for frame_index in range(num_frames):
frame_start = frame_index * self.frame_size
pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
current_c = c[:, frame_index : frame_index + 1, :]
for i in range(self.frame_size):
pcm_position = frame_start + i + lpc_order
output_position = frame_start + i
# prepare input
pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
if 'prediction' in self.input_layout['signals']:
input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)
# run single step of sample rate network
probs, gru_states = self.decoder(
input_signals,
current_c,
gru_states
)
# sample from output
exc_ulaw = sample_excitation(probs, pitch_corr)
# signal generation
exc = ulaw2lin(exc_ulaw)
sig = exc + pred
pcm[pcm_position] = sig
mem = 0.85 * mem + float(sig)
output[output_position] = clip_to_int16(round(mem))
# buffer update
if 'last_signal' in self.input_layout['signals']:
input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)
if 'last_error' in self.input_layout['signals']:
input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc)
return output

View file

@@ -0,0 +1,408 @@
import torch
from torch import nn
from utils.layers.subconditioner import get_subconditioner
from utils.layers import DualFC
from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.misc import interleave_tensors
# MultiRateLPCNet
class MultiRateLPCNet(nn.Module):
def __init__(self, config):
super(MultiRateLPCNet, self).__init__()
# general parameters
self.input_layout = config['input_layout']
self.feature_history = config['feature_history']
self.feature_lookahead = config['feature_lookahead']
self.signals = config['signals']
# frame rate network parameters
self.feature_dimension = config['feature_dimension']
self.period_embedding_dim = config['period_embedding_dim']
self.period_levels = config['period_levels']
self.feature_channels = self.feature_dimension + self.period_embedding_dim
self.feature_conditioning_dim = config['feature_conditioning_dim']
self.feature_conv_kernel_size = config['feature_conv_kernel_size']
# frame rate network layers
self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
self.feature_dense2 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
# sample rate network parameters
self.frame_size = config['frame_size']
self.signal_levels = config['signal_levels']
self.signal_embedding_dim = config['signal_embedding_dim']
self.gru_a_units = config['gru_a_units']
self.gru_b_units = config['gru_b_units']
self.output_levels = config['output_levels']
# subconditioning B
sub_config = config['subconditioning']['subconditioning_b']
self.substeps_b = sub_config['number_of_subsamples']
self.subcondition_signals_b = sub_config['signals']
self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
method = sub_config['method']
kwargs = sub_config['kwargs']
if kwargs is None:
kwargs = dict()
state_size = self.gru_b_units
self.subconditioner_b = get_subconditioner(method,
sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
state_size, self.signal_levels, len(sub_config['signals']),
**kwargs)
# subconditioning A
sub_config = config['subconditioning']['subconditioning_a']
self.substeps_a = sub_config['number_of_subsamples']
self.subcondition_signals_a = sub_config['signals']
self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
method = sub_config['method']
kwargs = sub_config['kwargs']
if kwargs is None:
kwargs = dict()
state_size = self.gru_a_units
self.subconditioner_a = get_subconditioner(method,
sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
**kwargs)
# wrap up subconditioning: group_size_gru_a holds the number of
# timesteps grouped into one GRU A input, and group_size_subcondition_a
# holds the number of samples grouped into one input of the pre-GRU B
# subconditioning
self.group_size_gru_a = self.substeps_a * self.substeps_b
self.group_size_subcondition_a = self.substeps_b
self.gru_a_rate_divider = self.group_size_gru_a
self.gru_b_rate_divider = self.substeps_b
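# illustrative example: with substeps_a = 2 and substeps_b = 2, GRU A
# consumes groups of 4 samples (rate fs / 4), subconditioner A and
# GRU B run at fs / 2, and subconditioner B plus the dual FCs run at
# the full sampling rate fs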
# gru sizes
self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]
# sample rate network layers
self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
# sparsification
self.sparsifier = []
# GRU A
if 'gru_a' in config['sparsification']:
gru_config = config['sparsification']['gru_a']
task_list = [(self.gru_a, gru_config['params'])]
self.sparsifier.append(GRUSparsifier(task_list,
gru_config['start'],
gru_config['stop'],
gru_config['interval'],
gru_config['exponent'])
)
self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
gru_config['params'], drop_input=True)
else:
self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)
# GRU B
if 'gru_b' in config['sparsification']:
gru_config = config['sparsification']['gru_b']
task_list = [(self.gru_b, gru_config['params'])]
self.sparsifier.append(GRUSparsifier(task_list,
gru_config['start'],
gru_config['stop'],
gru_config['interval'],
gru_config['exponent'])
)
self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
gru_config['params'])
else:
self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)
# dual FCs
self.dual_fc = []
for i in range(self.substeps_b):
dim = self.subconditioner_b.get_output_dim(i)
self.dual_fc.append(DualFC(dim, self.output_levels))
self.add_module(f"dual_fc_{i}", self.dual_fc[-1])
def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
gflops = 0
# frame rate network
conditioning_dim = self.feature_conditioning_dim
feature_channels = self.feature_channels
frame_rate = fs / self.frame_size
frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
if verbose:
print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
gflops += frame_rate_network_complexity
# gru a
gru_a_rate = fs / self.group_size_gru_a
gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
if verbose:
print(f"gru A: {gru_a_complexity} GFLOPS")
gflops += gru_a_complexity
# subconditioning a
subcond_a_rate = fs / self.substeps_b
subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
if verbose:
print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
gflops += subconditioning_a_complexity
# gru b
gru_b_rate = fs / self.substeps_b
gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
if verbose:
print(f"gru B: {gru_b_complexity} GFLOPS")
gflops += gru_b_complexity
# subconditioning b
subcond_b_rate = fs
subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
if verbose:
print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
gflops += subconditioning_b_complexity
# dual fcs
for i, fc in enumerate(self.dual_fc):
rate = fs / len(self.dual_fc)
input_size = fc.dense1.in_features
output_size = fc.dense1.out_features
dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
if hierarchical_sampling:
dual_fc_complexity /= 8
if verbose:
print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
gflops += dual_fc_complexity
if verbose:
print(f'total: {gflops} GFLOPS')
return gflops
def sparsify(self):
for sparsifier in self.sparsifier:
sparsifier.step()
def frame_rate_network(self, features, periods):
embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
features = torch.concat((features, embedded_periods), dim=-1)
# convert to channels first and calculate conditioning vector
c = torch.permute(features, [0, 2, 1])
c = torch.tanh(self.feature_conv1(c))
c = torch.tanh(self.feature_conv2(c))
# back to channels last
c = torch.permute(c, [0, 2, 1])
c = torch.tanh(self.feature_dense1(c))
c = torch.tanh(self.feature_dense2(c))
return c
def prepare_signals(self, signals, group_size, signal_idx):
""" extracts, delays and groups signals """
batch_size, sequence_length, num_signals = signals.shape
# extract signals according to position
signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
dim=-1)
# roll back pcm to account for grouping
signals = torch.roll(signals, group_size - 1, -2)
# reshape
signals = torch.reshape(signals,
(batch_size, sequence_length // group_size, group_size * len(signal_idx)))
return signals
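# shape sketch: (batch, T, num_signals) -> (batch, T // group_size,
# group_size * len(signal_idx)); the roll by group_size - 1 shifts the
# selected signals so that each group ends at the current timestep
# instead of reaching group_size - 1 samples into the future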
def sample_rate_network(self, signals, c, gru_states):
signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
# features at GRU A rate
c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
# features at GRU B rate
c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)
y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
y, gru_a_state = self.gru_a(y, gru_states[0])
# first round of upsampling and subconditioning
c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
y = self.subconditioner_a(y, c_signals_a)
y = interleave_tensors(y)
y = torch.concat((y, c_upsampled_b), dim=-1)
y, gru_b_state = self.gru_b(y, gru_states[1])
c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
y = self.subconditioner_b(y, c_signals_b)
y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
y = interleave_tensors(y)
return y, (gru_a_state, gru_b_state)
def decoder(self, signals, c, gru_states):
embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
y = torch.concat((embedded_signals, c), dim=-1)
y, gru_a_state = self.gru_a(y, gru_states[0])
y = torch.concat((y, c), dim=-1)
y, gru_b_state = self.gru_b(y, gru_states[1])
y = self.dual_fc(y)
return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)
def forward(self, features, periods, signals, gru_states):
c = self.frame_rate_network(features, periods)
y, _ = self.sample_rate_network(signals, c, gru_states)
log_probs = torch.log_softmax(y, dim=-1)
return log_probs
def generate(self, features, periods, lpcs):
with torch.no_grad():
device = next(self.parameters()).device
num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
lpc_order = lpcs.shape[-1]
num_input_signals = len(self.signals)
pitch_corr_position = self.input_layout['features']['pitch_corr'][0]
# signal buffers
last_signal = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
prediction = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
last_error = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
mem = 0
# state buffers
gru_a_state = torch.zeros((1, 1, self.gru_a_units))
gru_b_state = torch.zeros((1, 1, self.gru_b_units))
input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
# conditioning signals for subconditioner a
c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
# conditioning signals for subconditioner b
c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)
# signal dict
signal_dict = {
'prediction' : prediction,
'last_error' : last_error,
'last_signal' : last_signal
}
# push data to device
features = features.to(device)
periods = periods.to(device)
lpcs = lpcs.to(device)
# run feature encoding
c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))
for frame_index in range(num_frames):
frame_start = frame_index * self.frame_size
pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
current_c = c[:, frame_index : frame_index + 1, :]
for i in range(0, self.frame_size, self.group_size_gru_a):
pcm_position = frame_start + i + lpc_order
output_position = frame_start + i
# calculate newest prediction
prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
# prepare input
for slot in range(self.group_size_gru_a):
k = slot - self.group_size_gru_a + 1
for idx, name in enumerate(self.signals):
input_signals[idx + slot * num_input_signals] = lin2ulawq(
signal_dict[name][pcm_position + k]
)
# run GRU A
embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
embed_signals = torch.flatten(embed_signals, 2)
y = torch.cat((embed_signals, current_c), dim=-1)
h_a, gru_a_state = self.gru_a(y, gru_a_state)
# loop over substeps_a
for step_a in range(self.substeps_a):
# prepare conditioning input
for slot in range(self.group_size_subcondition_a):
k = slot - self.group_size_subcondition_a + 1
for idx, name in enumerate(self.subcondition_signals_a):
# stride by the number of subconditioning signals (buffer layout: slots x signals)
c_signals_a[idx + slot * len(self.subcondition_signals_a)] = lin2ulawq(
signal_dict[name][pcm_position + k]
)
# subconditioning
h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))
# run GRU B
y = torch.cat((h_a, current_c), dim=-1)
h_b, gru_b_state = self.gru_b(y, gru_b_state)
# loop over substeps b
for step_b in range(self.substeps_b):
# prepare subconditioning input
for idx, name in enumerate(self.subcondition_signals_b):
c_signals_b[idx] = lin2ulawq(
signal_dict[name][pcm_position]
)
# subcondition
h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))
# run dual FC
probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)
# sample
new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))
# update signals
sig = new_exc + prediction[pcm_position]
last_error[pcm_position + 1] = new_exc
last_signal[pcm_position + 1] = sig
mem = 0.85 * mem + float(sig)
output[output_position] = clip_to_int16(round(mem))
# increase positions
pcm_position += 1
output_position += 1
# calculate next prediction
prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
return output

View file

@@ -0,0 +1,35 @@
import argparse
import yaml
from models import model_dict
debug = False
if debug:
args = type('dummy', (object,),
{
'setup' : 'setups/lpcnet_m/setup_1_4_concatenative.yml',
'hierarchical_sampling' : False
})()
else:
parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('--hierarchical-sampling', action="store_true", help='whether to assume hierarchical sampling (default=False)', default=False)
args = parser.parse_args()
with open(args.setup, 'r') as f:
setup = yaml.load(f.read(), yaml.FullLoader)
# check model
if 'model' not in setup['lpcnet']:
print('warning: did not find model entry in setup, using default lpcnet')
model_name = 'lpcnet'
else:
model_name = setup['lpcnet']['model']
# create model
model = model_dict[model_name](setup['lpcnet']['config'])
# LPCNet.get_gflops takes no hierarchical_sampling argument (it reads
# 'hsampling' from its config), so only pass it to the multi-rate model
if model_name == 'multi_rate':
gflops = model.get_gflops(16000, verbose=True, hierarchical_sampling=args.hierarchical_sampling)
else:
gflops = model.get_gflops(16000, verbose=True)

View file

@@ -0,0 +1,161 @@
import argparse
import os
from uuid import UUID
from collections import OrderedDict
import pickle
import torch
import numpy as np
import utils
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input folder containing multi-run output")
parser.add_argument("tag", type=str, help="tag for multi-run experiment")
parser.add_argument("csv", type=str, help="name for output csv")
def is_uuid(val):
try:
UUID(val)
return True
except ValueError:
return False
def collect_results(folder):
training_folder = os.path.join(folder, 'training')
testing_folder = os.path.join(folder, 'testing')
# validation loss
checkpoint = torch.load(os.path.join(training_folder, 'checkpoints', 'checkpoint_finalize_epoch_1.pth'), map_location='cpu')
validation_loss = checkpoint['validation_loss']
# eval_warpq
eval_warpq = utils.data.parse_warpq_scores(os.path.join(training_folder, 'out_finalize.txt'))[-1]
# testing results
testing_results = utils.data.collect_test_stats(os.path.join(testing_folder, 'final'))
results = OrderedDict()
results['eval_loss'] = validation_loss
results['eval_warpq'] = eval_warpq
results['pesq_mean'] = testing_results['pesq'][0]
results['warpq_mean'] = testing_results['warpq'][0]
results['pitch_error_mean'] = testing_results['pitch_error'][0]
results['voicing_error_mean'] = testing_results['voicing_error'][0]
return results
def print_csv(path, results, tag, ranks=None, header=True):
metrics = next(iter(results.values())).keys()
if ranks is not None:
rank_keys = next(iter(ranks.values())).keys()
else:
rank_keys = []
with open(path, 'w') as f:
if header:
f.write("uuid, tag")
for metric in metrics:
f.write(f", {metric}")
for rank in rank_keys:
f.write(f", {rank}")
f.write("\n")
for uuid, values in results.items():
f.write(f"{uuid}, {tag}")
for val in values.values():
f.write(f", {val:10.8f}")
for rank in rank_keys:
f.write(f", {ranks[uuid][rank]:4d}")
f.write("\n")
def get_ranks(results):
metrics = list(next(iter(results.values())).keys())
positive = {'pesq_mean', 'mix'}
ranks = OrderedDict()
for key in results.keys():
ranks[key] = OrderedDict()
for metric in metrics:
sign = -1 if metric in positive else 1
x = sorted([(key, value[metric]) for key, value in results.items()], key=lambda x: sign * x[1])
x = [y[0] for y in x]
for key in results.keys():
ranks[key]['rank_' + metric] = x.index(key) + 1
return ranks
def analyse_metrics(results):
metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']
x = []
for metric in metrics:
x.append([val[metric] for val in results.values()])
x = np.array(x)
print(x)
def add_mix_metric(results):
metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']
x = []
for metric in metrics:
x.append([val[metric] for val in results.values()])
x = np.array(x).transpose() * np.array([-1, 1, -1, -1, -1])
z = (x - np.mean(x, axis=0)) / np.std(x, axis=0)
print(f"covariance matrix for normalized scores of {metrics}:")
print(np.cov(z.transpose()))
score = np.mean(z, axis=1)
for i, key in enumerate(results.keys()):
results[key]['mix'] = score[i].item()
if __name__ == "__main__":
args = parser.parse_args()
uuids = sorted([x for x in os.listdir(args.input) if os.path.isdir(os.path.join(args.input, x)) and is_uuid(x)])
results = OrderedDict()
for uuid in uuids:
results[uuid] = collect_results(os.path.join(args.input, uuid))
add_mix_metric(results)
ranks = get_ranks(results)
csv = args.csv if args.csv.endswith('.csv') else args.csv + '.csv'
print_csv(csv, results, args.tag, ranks=ranks)
with open(csv[:-4] + '.pickle', 'wb') as f:
pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)

View file

@@ -0,0 +1,52 @@
#!/bin/bash
case $# in
9) SETUP=$1; OUTDIR=$2; NAME=$3; DEVICE=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
*) echo "loop_run.sh setup outdir name device rounds lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
esac
PYTHON="/home/ubuntu/opt/miniconda3/envs/torch/bin/python"
TESTFEATURES=${LPCNEXT}/testitems/features/all_0_orig_features.f32
WARPQREFERENCE=${LPCNEXT}/testitems/wav/all_0_orig.wav
METRICS="warpq,pesq,pitch_error,voicing_error"
LPCNETDEMO=${LPCNET}/lpcnet_demo
for ((round = 1; round <= $ROUNDS; round++))
do
echo
echo round $round
UUID=$(uuidgen)
TRAINOUT=${OUTDIR}/${UUID}/training
TESTOUT=${OUTDIR}/${UUID}/testing
CHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_last.pth
FINALCHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_finalize_last.pth
# run training
echo "starting training..."
$PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT --device $DEVICE --test-features $TESTFEATURES --warpq-reference $WARPQREFERENCE
# run finalization
echo "starting finalization..."
$PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT \
--device $DEVICE --test-features $TESTFEATURES \
--warpq-reference $WARPQREFERENCE \
--finalize --initial-checkpoint $CHECKPOINT
# create test configs
$PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig.yml "$NAME $UUID" $CHECKPOINT --lpcnet-demo $LPCNETDEMO
$PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig_finalize.yml "$NAME $UUID finalized" $FINALCHECKPOINT --lpcnet-demo $LPCNETDEMO
# run tests
echo "starting test 1 (no finalization)..."
$PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig.yml \
$TESTITEMS ${TESTOUT}/prefinal --num-workers 8 \
--num-testitems 400 --metrics $METRICS
echo "starting test 2 (after finalization)..."
$PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig_finalize.yml \
$TESTITEMS ${TESTOUT}/final --num-workers 8 \
--num-testitems 400 --metrics $METRICS
done

View file

@@ -0,0 +1,37 @@
""" script for creating animations from debug data
"""
import argparse
import sys
sys.path.append('./')
from utils.endoscopy import make_animation, read_data
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='endoscopy folder with debug output')
parser.add_argument('output', type=str, help='output file (will be auto-extended with .mp4)')
parser.add_argument('--start-index', type=int, help='index of first sample to be considered', default=0)
parser.add_argument('--stop-index', type=int, help='index of last sample to be considered', default=-1)
parser.add_argument('--interval', type=int, help='interval between frames in ms', default=20)
parser.add_argument('--half-window-length', type=int, help='half size of window for displaying signals', default=80)
if __name__ == "__main__":
args = parser.parse_args()
filename = args.output if args.output.endswith('.mp4') else args.output + '.mp4'
data = read_data(args.folder)
make_animation(
data,
filename,
start_index=args.start_index,
stop_index=args.stop_index,
half_signal_window_length=args.half_window_length
)

View file

@@ -0,0 +1,17 @@
import argparse
import numpy as np
parser = argparse.ArgumentParser(description="sets s_t to augmented_s_t")
parser.add_argument('datafile', type=str, help='data.s16 file path')
args = parser.parse_args()
data = np.memmap(args.datafile, dtype='int16', mode='readwrite')
# signal is in data[1::2]
# last augmented signal is in data[0::2]
data[1:-1:2] = data[2::2]

View file

@@ -0,0 +1,17 @@
#!/bin/bash
case $# in
9) SETUP=$1; OUTDIR=$2; NAME=$3; NUMDEVICES=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
*) echo "multi_run.sh setup outdir name num_devices rounds_per_device lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
esac
LOOPRUN=${LPCNEXT}/loop_run.sh
mkdir -p $OUTDIR
for ((i = 0; i < $NUMDEVICES; i++))
do
echo "launching job queue for device $i"
nohup bash $LOOPRUN $SETUP $OUTDIR "$NAME" "cuda:$i" $ROUNDS $LPCNEXT $LPCNET $TESTSUITE $TESTITEMS > $OUTDIR/job_${i}_out.txt &
done

View file

@@ -0,0 +1,22 @@
#!/bin/bash
case $# in
3) FEATURES=$1; FOLDER=$2; PYTHON=$3;;
*) echo "run_inference_test.sh <features file> <output folder> <python path>"; exit;;
esac
SCRIPTFOLDER=$(dirname "$0")
mkdir -p $FOLDER/inference_test
# run inference for all checkpoints
for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
do
tmp=$(basename $fn)
tmp=${tmp%.pth}
epoch=${tmp#checkpoint_epoch_}
echo "running inference with checkpoint $fn..."
$PYTHON $SCRIPTFOLDER/../test_lpcnet.py $FEATURES $fn $FOLDER/inference_test/output_epoch_${epoch}.wav
done

View file

@@ -0,0 +1,25 @@
""" script for updating checkpoints with new setup entries
Use this script to update older outputs with newly introduced
parameters. (Saves us the trouble of backward compatibility)
"""
import argparse
import torch
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint_file', type=str, help='checkpoint to be updated')
parser.add_argument('--model', type=str, help='model update', default=None)
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_file, map_location='cpu')
# update model entry
if args.model is not None:
checkpoint['setup']['lpcnet']['model'] = args.model
torch.save(checkpoint, args.checkpoint_file)

View file

@@ -0,0 +1,22 @@
#!/bin/bash
case $# in
3) FOLDER=$1; MODEL=$2; PYTHON=$3;;
*) echo "update_output_folder.sh folder model python"; exit;;
esac
SCRIPTFOLDER=$(dirname "$0")
# update setup
echo "updating $FOLDER/setup.py..."
$PYTHON $SCRIPTFOLDER/update_setups.py $FOLDER/setup.yml --model $MODEL
# update checkpoints
for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
do
echo "updating $fn..."
$PYTHON $SCRIPTFOLDER/update_checkpoints.py $fn --model $MODEL
done

View file

@@ -0,0 +1,28 @@
""" script for updating setup files with new setup entries
Use this script to update older outputs with newly introduced
parameters. (Saves us the trouble of backward compatibility)
"""
import argparse
import yaml
parser = argparse.ArgumentParser()
parser.add_argument('setup_file', type=str, help='setup to be updated')
parser.add_argument('--model', type=str, help='model update', default=None)
args = parser.parse_args()
# load setup
with open(args.setup_file, 'r') as f:
setup = yaml.load(f.read(), yaml.FullLoader)
# update model entry
if args.model is not None:
setup['lpcnet']['model'] = args.model
# dump result
with open(args.setup_file, 'w') as f:
yaml.dump(setup, f)

View file

@@ -0,0 +1,60 @@
import argparse
import torch
import numpy as np
from models import model_dict
from utils.data import load_features
from utils.wav import wavwrite16
debug = False
if debug:
args = type('dummy', (object,),
{
'features' : 'features.f32',
'checkpoint' : 'checkpoint.pth',
'output' : 'out.wav',
'version' : 2
})()
else:
parser = argparse.ArgumentParser()
parser.add_argument('features', type=str, help='feature file')
parser.add_argument('checkpoint', type=str, help='checkpoint file')
parser.add_argument('output', type=str, help='output file')
parser.add_argument('--version', type=int, help='feature version', default=2)
args = parser.parse_args()
torch.set_num_threads(2)
version = args.version
feature_file = args.features
checkpoint_file = args.checkpoint
output_file = args.output
if not output_file.endswith('.wav'):
output_file += '.wav'
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# check model
if 'model' not in checkpoint['setup']['lpcnet']:
print('warning: did not find model entry in setup, using default lpcnet')
model_name = 'lpcnet'
else:
model_name = checkpoint['setup']['lpcnet']['model']
model = model_dict[model_name](checkpoint['setup']['lpcnet']['config'])
model.load_state_dict(checkpoint['state_dict'])
data = load_features(feature_file)
output = model.generate(data['features'], data['periods'], data['lpcs'])
wavwrite16(output_file, output.numpy(), 16000)

View file

@@ -0,0 +1,243 @@
import os
import argparse
import sys
try:
import git
has_git = True
except ImportError:
has_git = False
import yaml
import torch
from torch.optim.lr_scheduler import LambdaLR
from data import LPCNetDataset
from models import model_dict
from engine.lpcnet_engine import train_one_epoch, evaluate
from utils.data import load_features
from utils.wav import wavwrite16
debug = False
if debug:
args = type('dummy', (object,),
{
'setup' : 'setup.yml',
'output' : 'testout',
'device' : None,
'test_features' : None,
'finalize': False,
'initial_checkpoint': None,
'no_redirect': False
})()
else:
parser = argparse.ArgumentParser("train_lpcnet.py")
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of output')
args = parser.parse_args()
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
setup = yaml.load(f.read(), yaml.FullLoader)
if args.finalize:
if args.initial_checkpoint is None:
raise ValueError('finalization requires initial checkpoint')
if 'sparsification' in setup['lpcnet']['config']:
for sp_job in setup['lpcnet']['config']['sparsification'].values():
sp_job['start'], sp_job['stop'] = 0, 0
setup['training']['lr'] = 1.0e-5
setup['training']['lr_decay_factor'] = 0.0
setup['training']['epochs'] = 1
checkpoint_prefix = 'checkpoint_finalize'
output_prefix = 'output_finalize'
setup_name = 'setup_finalize.yml'
output_file='out_finalize.txt'
else:
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'
# check model
if 'model' not in setup['lpcnet']:
print('warning: did not find model entry in setup, using default lpcnet')
model_name = 'lpcnet'
else:
model_name = setup['lpcnet']['model']
# prepare output folder
if os.path.exists(args.output) and not debug and not args.finalize:
print("warning: output folder exists")
reply = input('continue? (y/n): ')
while reply not in {'y', 'n'}:
reply = input('continue? (y/n): ')
if reply == 'n':
sys.exit(0)
else:
os.makedirs(args.output, exist_ok=True)
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup
if has_git:
working_dir = os.path.split(__file__)[0]
try:
repo = git.Repo(working_dir)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
is_dirty = repo.is_dirty()
if is_dirty:
print("warning: repo is dirty")
setup['repo']['hash'] = hash
setup['repo']['urls'] = urls
setup['repo']['dirty'] = is_dirty
except Exception:
has_git = False
# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
yaml.dump(setup, f)
# prepare inference test if wanted
run_inference_test = False
if args.test_features is not None:
test_features = load_features(args.test_features)
inference_test_dir = os.path.join(args.output, 'inference_test')
os.makedirs(inference_test_dir, exist_ok=True)
run_inference_test = True
# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
# load training dataset
lpcnet_config = setup['lpcnet']['config']
data = LPCNetDataset( setup['dataset'],
features=lpcnet_config['features'],
input_signals=lpcnet_config['signals'],
target=lpcnet_config['target'],
frames_per_sample=setup['training']['frames_per_sample'],
feature_history=lpcnet_config['feature_history'],
feature_lookahead=lpcnet_config['feature_lookahead'],
lpc_gamma=lpcnet_config.get('lpc_gamma', 1))
# load validation dataset if given
if 'validation_dataset' in setup:
validation_data = LPCNetDataset( setup['validation_dataset'],
features=lpcnet_config['features'],
input_signals=lpcnet_config['signals'],
target=lpcnet_config['target'],
frames_per_sample=setup['training']['frames_per_sample'],
feature_history=lpcnet_config['feature_history'],
feature_lookahead=lpcnet_config['feature_lookahead'],
lpc_gamma=lpcnet_config.get('lpc_gamma', 1))
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
run_validation = True
else:
run_validation = False
# create model
model = model_dict[model_name](setup['lpcnet']['config'])
if args.initial_checkpoint is not None:
print(f"loading state dict from {args.initial_checkpoint}...")
chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
model.load_state_dict(chkpt['state_dict'])
# set compute device
if args.device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
# push model to device
model.to(device)
# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
# optimizer over trainable parameters only
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)
# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
# loss
criterion = torch.nn.NLLLoss()
# model checkpoint
checkpoint = {
'setup' : setup,
'state_dict' : model.state_dict(),
'loss' : -1
}
if not args.no_redirect:
print(f"re-directing output to {os.path.join(args.output, output_file)}")
sys.stdout = open(os.path.join(args.output, output_file), "w")
best_loss = 1e9
for ep in range(1, epochs + 1):
print(f"training epoch {ep}...")
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
# save checkpoint
checkpoint['state_dict'] = model.state_dict()
checkpoint['loss'] = new_loss
if run_validation:
print("running validation...")
validation_loss = evaluate(model, criterion, validation_dataloader, device)
checkpoint['validation_loss'] = validation_loss
if validation_loss < best_loss:
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
best_loss = validation_loss
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))
# run inference test
if run_inference_test:
model.to("cpu")
print("running inference test...")
output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])
testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')
wavwrite16(testfilename, output.numpy(), 16000)
model.to(device)
print()

View file

@ -0,0 +1,4 @@
from . import sparsification
from . import data
from . import pcm
from . import sample

View file

@ -0,0 +1,112 @@
import os
import torch
import numpy as np
def load_features(feature_file, version=2):
if version == 2:
layout = {
'cepstrum': [0,18],
'periods': [18, 19],
'pitch_corr': [19, 20],
'lpc': [20, 36]
}
frame_length = 36
elif version == 1:
layout = {
'cepstrum': [0,18],
'periods': [36, 37],
'pitch_corr': [37, 38],
'lpc': [39, 55],
}
frame_length = 55
else:
raise ValueError(f'unknown feature version: {version}')
raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
raw_features = raw_features.reshape((-1, frame_length))
features = torch.cat(
[
raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
],
dim=1
)
lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()
return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
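# Minimal usage sketch ('features.f32' is a placeholder path produced by
# dump_data); the returned entries feed model.generate as in the inference
# test of train_lpcnet.py:
#
#   data = load_features('features.f32')
#   signal = model.generate(data['features'], data['periods'], data['lpcs'])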
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
ref_data = np.memmap(reference_data_path, dtype=np.int16)
signal = np.memmap(signal_path, dtype=np.int16)
signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)
assert len(signal) % 160 == 0
num_frames = len(signal) // 160
mem = np.zeros(1)
for fr in range(num_frames):
signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]
new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
new_data[:] = 0
N = len(signal) - offset
new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
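# note: with the v2 signal_frame_layout {'last_signal': 0, 'signal': 1}, index
# 2k+1 is the 'signal' slot of sample k and 2k+2 the 'last_signal' slot of
# sample k+1, so both carry the same preemphasized sample (an interpretation,
# not stated explicitly in the original code)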
def parse_warpq_scores(output_file):
""" extracts warpq scores from output file """
with open(output_file, "r") as f:
lines = f.readlines()
scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]
return scores
def parse_stats_file(file):
with open(file, "r") as f:
lines = f.readlines()
mean = float(lines[0].split(":")[-1])
bt_mean = float(lines[1].split(":")[-1])
top_mean = float(lines[2].split(":")[-1])
return mean, bt_mean, top_mean
def collect_test_stats(test_folder):
""" collects statistics for all discovered metrics from test folder """
metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
results = dict()
content = os.listdir(test_folder)
stats_files = [file for file in content if file.startswith('stats_')]
for file in stats_files:
metric = file[len("stats_") : -len(".txt")]
if metric not in metrics:
print(f"warning: unknown metric {metric}")
mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))
results[metric] = [mean, bt_mean, top_mean]
return results

View file

@ -0,0 +1,205 @@
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
hidden_size = gru.hidden_size
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
# reset gate
start, stop = 0 * hidden_size, 1 * hidden_size
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# update gate
start, stop = 1 * hidden_size, 2 * hidden_size
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# new gate
start, stop = 2 * hidden_size, 3 * hidden_size
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
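# Minimal sanity-check sketch (assumes a single-layer, unidirectional GRU):
#
#   gru = torch.nn.GRU(4, 8)
#   x, h = torch.randn(4), torch.randn(8)
#   gates = get_gru_gates(gru, x, h)
#   # pytorch update rule: h' = update * h + (1 - update) * new
#   h_new = gates['update_gate'] * h + (1 - gates['update_gate']) * gates['new_gate']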
def init(folder='endoscopy'):
""" sets up output folder for endoscopy data """
global _folder
_folder = folder
if not os.path.exists(folder):
os.makedirs(folder)
else:
print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
def write_data(key, data, fs):
""" appends data to previous data written under key """
global _state
# convert to numpy if torch.Tensor is given
if isinstance(data, torch.Tensor):
data = data.detach().numpy()
if not key in _state:
_state[key] = {
'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
'fs' : fs,
'dim' : tuple(data.shape),
'dtype' : str(data.dtype)
}
with open(os.path.join(_folder, key + '.yml'), 'w') as f:
f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
else:
if _state[key]['fs'] != fs:
raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
if _state[key]['dtype'] != str(data.dtype):
raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
if _state[key]['dim'] != tuple(data.shape):
raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
_state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
""" clean up """
for key in _state.keys():
_state[key]['fid'].close()
def read_data(folder='endoscopy'):
""" retrieves written data as numpy arrays """
keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
return_dict = dict()
for key in keys:
with open(os.path.join(folder, key + '.yml'), 'r') as f:
value = yaml.load(f.read(), yaml.FullLoader)
with open(os.path.join(folder, key + '.bin'), 'rb') as f:
data = np.frombuffer(f.read(), dtype=value['dtype'])
value['data'] = data.reshape((-1,) + value['dim'])
return_dict[key] = value
return return_dict
def get_best_reshape(shape, target_ratio=1):
""" calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
if len(shape) > 1:
pixel_count = 1
for s in shape:
pixel_count *= s
else:
pixel_count = shape[0]
if pixel_count == 1:
return (1,)
num_columns = int((pixel_count / target_ratio)**.5)
while (pixel_count % num_columns):
num_columns -= 1
num_rows = pixel_count // num_columns
return (num_rows, num_columns)
def get_type_and_shape(shape):
# can happen if data is scalar
if len(shape) == 0:
shape = (1,)
# calculate pixel count
if len(shape) > 1:
pixel_count = 1
for s in shape:
pixel_count *= s
else:
pixel_count = shape[0]
if pixel_count == 1:
return 'plot', (1, )
# stay with shape if already 2-dimensional
if len(shape) == 2:
if (shape[0] != pixel_count) or (shape[1] != pixel_count):
return 'image', shape
return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
# determine plot setup
num_keys = len(data.keys())
num_rows = int((num_keys * 3/4) ** .5)
num_cols = (num_keys + num_rows - 1) // num_rows
fig, axs = plt.subplots(num_rows, num_cols)
fig.set_size_inches(num_cols * 5, num_rows * 5)
display = dict()
fs_max = max([val['fs'] for val in data.values()])
num_samples = max([val['data'].shape[0] for val in data.values()])
keys = sorted(data.keys())
# inspect data
for i, key in enumerate(keys):
axs[i // num_cols, i % num_cols].title.set_text(key)
display[key] = dict()
display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
display[key]['down_factor'] = data[key]['fs'] / fs_max
start_index = max(start_index, half_signal_window_length)
while stop_index < 0:
stop_index += num_samples
stop_index = min(stop_index, num_samples - half_signal_window_length)
# actual plotting
frames = []
for index in range(start_index, stop_index):
ims = []
for i, key in enumerate(keys):
feature_index = int(round(index * display[key]['down_factor']))
if display[key]['type'] == 'plot':
ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
elif display[key]['type'] == 'image':
ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
frames.append(ims)
ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
if not filename.endswith('.mp4'):
filename += '.mp4'
ani.save(filename)
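# Typical workflow (a sketch; assumes the module is imported as endoscopy and
# that fs reflects how often write_data is called per second):
#
#   endoscopy.init('endoscopy')
#   endoscopy.write_data('gru_a_state', state, 100)   # repeatedly, during inference
#   endoscopy.close()
#   data = endoscopy.read_data('endoscopy')
#   make_animation(data, 'inspection.mp4')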

View file

@ -0,0 +1,3 @@
from .dual_fc import DualFC
from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding

View file

@ -0,0 +1,15 @@
import torch
from torch import nn
class DualFC(nn.Module):
def __init__(self, input_dim, output_dim):
super(DualFC, self).__init__()
self.dense1 = nn.Linear(input_dim, output_dim)
self.dense2 = nn.Linear(input_dim, output_dim)
self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
def forward(self, x):
return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x))
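# Example (a sketch): DualFC mixes two tanh-activated linear maps with the
# learned scalar weights alpha and beta:
#
#   dual_fc = DualFC(input_dim=640, output_dim=256)
#   y = dual_fc(torch.randn(8, 640))   # shape (8, 256)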

View file

@ -0,0 +1,42 @@
""" module implementing PCM embeddings for LPCNet """
import math as m
import torch
from torch import nn
class PCMEmbedding(nn.Module):
def __init__(self, embed_dim=128, num_levels=256):
super(PCMEmbedding, self).__init__()
self.embed_dim = embed_dim
self.num_levels = num_levels
self.embedding = nn.Embedding(self.num_levels, self.embed_dim)
# initialize
with torch.no_grad():
num_rows, num_cols = self.num_levels, self.embed_dim
a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
for i in range(num_rows):
a[i, :] += m.sqrt(12) * (i - num_rows / 2)
self.embedding.weight[:, :] = 0.1 * a
def forward(self, x):
return self.embedding(x)
class DifferentiablePCMEmbedding(PCMEmbedding):
def __init__(self, embed_dim, num_levels=256):
super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)
def forward(self, x):
x_int = torch.floor(x).detach().long()
x_frac = (x - x_int).unsqueeze(-1)
x_next = torch.clamp(x_int + 1, max=self.num_levels - 1)
embed_0 = self.embedding(x_int)
embed_1 = self.embedding(x_next)
return (1 - x_frac) * embed_0 + x_frac * embed_1
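# Example (a sketch): fractional inputs are linearly interpolated between the
# two neighbouring embedding vectors, which keeps the lookup differentiable
# with respect to x:
#
#   embed = DifferentiablePCMEmbedding(embed_dim=128)
#   y = embed(torch.tensor([100.25]))   # 0.75 * table[100] + 0.25 * table[101]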

View file

@ -0,0 +1,468 @@
from re import sub
import torch
from torch import nn
def get_subconditioner( method,
number_of_subsamples,
pcm_embedding_size,
state_size,
pcm_levels,
number_of_signals,
**kwargs):
subconditioner_dict = {
'additive' : AdditiveSubconditioner,
'concatenative' : ConcatenativeSubconditioner,
'modulative' : ModulativeSubconditioner
}
return subconditioner_dict[method](number_of_subsamples,
pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs)
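# Example (a sketch; the sizes are illustrative only):
#
#   sub = get_subconditioner('modulative', number_of_subsamples=2,
#                            pcm_embedding_size=64, state_size=384,
#                            pcm_levels=256, number_of_signals=3)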
class Subconditioner(nn.Module):
def __init__(self):
""" upsampling by subconditioning
Upsamples a sequence of states conditioning on pcm signals and
optionally a feature vector.
"""
super(Subconditioner, self).__init__()
def forward(self, states, signals, features=None):
raise NotImplementedError("Subconditioner base class should not be called")
def single_step(self, index, state, signals, features):
raise NotImplementedError("Subconditioner base class should not be called")
def get_output_dim(self, index):
raise NotImplementedError("Subconditioner base class should not be called")
class AdditiveSubconditioner(Subconditioner):
def __init__(self,
number_of_subsamples,
pcm_embedding_size,
state_size,
pcm_levels,
number_of_signals,
**kwargs):
""" subconditioning by addition """
super(AdditiveSubconditioner, self).__init__()
self.number_of_subsamples = number_of_subsamples
self.pcm_embedding_size = pcm_embedding_size
self.state_size = state_size
self.pcm_levels = pcm_levels
self.number_of_signals = number_of_signals
if self.pcm_embedding_size != self.state_size:
raise ValueError('For additive subconditioning, state and embedding '
+ f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')
self.embeddings = [None]
for i in range(1, self.number_of_subsamples):
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
self.add_module('pcm_embedding_' + str(i), embedding)
self.embeddings.append(embedding)
def forward(self, states, signals):
""" creates list of subconditioned states
Parameters:
-----------
states : torch.tensor
states of shape (batch, seq_length // s, state_size)
signals : torch.tensor
signals of shape (batch, seq_length, number_of_signals)
Returns:
--------
c_states : list of torch.tensor
list of s subconditioned states
"""
s = self.number_of_subsamples
c_states = [states]
new_states = states
for i in range(1, self.number_of_subsamples):
embed = self.embeddings[i](signals[:, i::s])
# reduce signal dimension
embed = torch.sum(embed, dim=2)
new_states = new_states + embed
c_states.append(new_states)
return c_states
def single_step(self, index, state, signals):
""" carry out single step for inference
Parameters:
-----------
index : int
position in subconditioning batch
state : torch.tensor
state to sub-condition
signals : torch.tensor
signals for subconditioning, all but the last dimensions
must match those of state
Returns:
c_state : torch.tensor
subconditioned state
"""
if index == 0:
c_state = state
else:
embed_signals = self.embeddings[index](signals)
c = torch.sum(embed_signals, dim=-2)
c_state = state + c
return c_state
def get_output_dim(self, index):
return self.state_size
def get_average_flops_per_step(self):
s = self.number_of_subsamples
flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
return flops
class ConcatenativeSubconditioner(Subconditioner):
def __init__(self,
number_of_subsamples,
pcm_embedding_size,
state_size,
pcm_levels,
number_of_signals,
recurrent=True,
**kwargs):
""" subconditioning by concatenation """
super(ConcatenativeSubconditioner, self).__init__()
self.number_of_subsamples = number_of_subsamples
self.pcm_embedding_size = pcm_embedding_size
self.state_size = state_size
self.pcm_levels = pcm_levels
self.number_of_signals = number_of_signals
self.recurrent = recurrent
self.embeddings = []
start_index = 0
if self.recurrent:
start_index = 1
self.embeddings.append(None)
for i in range(start_index, self.number_of_subsamples):
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
self.add_module('pcm_embedding_' + str(i), embedding)
self.embeddings.append(embedding)
def forward(self, states, signals):
""" creates list of subconditioned states
Parameters:
-----------
states : torch.tensor
states of shape (batch, seq_length // s, state_size)
signals : torch.tensor
signals of shape (batch, seq_length, number_of_signals)
Returns:
--------
c_states : list of torch.tensor
list of s subconditioned states
"""
s = self.number_of_subsamples
if self.recurrent:
c_states = [states]
start = 1
else:
c_states = []
start = 0
new_states = states
for i in range(start, self.number_of_subsamples):
embed = self.embeddings[i](signals[:, i::s])
# reduce signal dimension
embed = torch.flatten(embed, -2)
if self.recurrent:
new_states = torch.cat((new_states, embed), dim=-1)
else:
new_states = torch.cat((states, embed), dim=-1)
c_states.append(new_states)
return c_states
def single_step(self, index, state, signals):
""" carry out single step for inference
Parameters:
-----------
index : int
position in subconditioning batch
state : torch.tensor
state to sub-condition
signals : torch.tensor
signals for subconditioning, all but the last dimensions
must match those of state
Returns:
c_state : torch.tensor
subconditioned state
"""
if index == 0 and self.recurrent:
c_state = state
else:
embed_signals = self.embeddings[index](signals)
c = torch.flatten(embed_signals, -2)
if not self.recurrent and index > 0:
# overwrite previous conditioning vector
c_state = torch.cat((state[...,:self.state_size], c), dim=-1)
else:
c_state = torch.cat((state, c), dim=-1)
return c_state
def get_average_flops_per_step(self):
return 0
def get_output_dim(self, index):
if self.recurrent:
return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
else:
return self.state_size + self.pcm_embedding_size * self.number_of_signals
class ModulativeSubconditioner(Subconditioner):
def __init__(self,
number_of_subsamples,
pcm_embedding_size,
state_size,
pcm_levels,
number_of_signals,
state_recurrent=False,
**kwargs):
""" subconditioning by modulation """
super(ModulativeSubconditioner, self).__init__()
self.number_of_subsamples = number_of_subsamples
self.pcm_embedding_size = pcm_embedding_size
self.state_size = state_size
self.pcm_levels = pcm_levels
self.number_of_signals = number_of_signals
self.state_recurrent = state_recurrent
self.hidden_size = self.pcm_embedding_size * self.number_of_signals
if self.state_recurrent:
self.hidden_size += self.pcm_embedding_size
self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)
self.embeddings = [None]
self.alphas = [None]
self.betas = [None]
for i in range(1, self.number_of_subsamples):
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
self.add_module('pcm_embedding_' + str(i), embedding)
self.embeddings.append(embedding)
self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
self.add_module('alpha_dense_' + str(i), self.alphas[-1])
self.betas.append(nn.Linear(self.hidden_size, self.state_size))
self.add_module('beta_dense_' + str(i), self.betas[-1])
def forward(self, states, signals):
""" creates list of subconditioned states
Parameters:
-----------
states : torch.tensor
states of shape (batch, seq_length // s, state_size)
signals : torch.tensor
signals of shape (batch, seq_length, number_of_signals)
Returns:
--------
c_states : list of torch.tensor
list of s subconditioned states
"""
s = self.number_of_subsamples
c_states = [states]
new_states = states
for i in range(1, self.number_of_subsamples):
embed = self.embeddings[i](signals[:, i::s])
# reduce signal dimension
embed = torch.flatten(embed, -2)
if self.state_recurrent:
comp_states = self.state_transform(new_states)
embed = torch.cat((embed, comp_states), dim=-1)
alpha = torch.tanh(self.alphas[i](embed))
beta = torch.tanh(self.betas[i](embed))
# new state obtained by modulating previous state
new_states = torch.tanh((1 + alpha) * new_states + beta)
c_states.append(new_states)
return c_states
def single_step(self, index, state, signals):
""" carry out single step for inference
Parameters:
-----------
index : int
position in subconditioning batch
state : torch.tensor
state to sub-condition
signals : torch.tensor
signals for subconditioning, all but the last dimensions
must match those of state
Returns:
c_state : torch.tensor
subconditioned state
"""
if index == 0:
c_state = state
else:
embed_signals = self.embeddings[index](signals)
c = torch.flatten(embed_signals, -2)
if self.state_recurrent:
r_state = self.state_transform(state)
c = torch.cat((c, r_state), dim=-1)
alpha = torch.tanh(self.alphas[index](c))
beta = torch.tanh(self.betas[index](c))
c_state = torch.tanh((1 + alpha) * state + beta)
return c_state
def get_output_dim(self, index):
return self.state_size
def get_average_flops_per_step(self):
s = self.number_of_subsamples
# estimate activation by 10 flops
# c_state = torch.tanh((1 + alpha) * state + beta)
flops = 13 * self.state_size
# hidden size
hidden_size = self.number_of_signals * self.pcm_embedding_size
if self.state_recurrent:
hidden_size += self.pcm_embedding_size
# counting 2 * A * B flops for Linear(A, B)
# alpha = torch.tanh(self.alphas[index](c))
# beta = torch.tanh(self.betas[index](c))
flops += 4 * hidden_size * self.state_size + 20 * self.state_size
# r_state = self.state_transform(state)
if self.state_recurrent:
flops += 2 * self.state_size * self.pcm_embedding_size
# average over steps
flops *= (s - 1) / s
return flops
class ComparitiveSubconditioner(Subconditioner):
def __init__(self,
number_of_subsamples,
pcm_embedding_size,
state_size,
pcm_levels,
number_of_signals,
error_index=-1,
apply_gate=True,
normalize=False):
""" subconditioning by comparison """
super(ComparitiveSubconditioner, self).__init__()
self.number_of_subsamples = number_of_subsamples
self.pcm_embedding_size = pcm_embedding_size
self.state_size = state_size
self.pcm_levels = pcm_levels
self.number_of_signals = number_of_signals
self.comparison_size = self.pcm_embedding_size
self.error_position = error_index
self.apply_gate = apply_gate
self.normalize = normalize
self.state_transform = nn.Linear(self.state_size, self.comparison_size)
self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
if self.apply_gate:
self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)
# embeddings and state transforms
self.embeddings = [None]
self.alpha_denses = [None]
self.beta_denses = [None]
self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
self.add_module('state_transform_0', self.state_transforms[0])
for i in range(1, self.number_of_subsamples):
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
self.add_module('pcm_embedding_' + str(i), embedding)
self.embeddings.append(embedding)
state_transform = nn.Linear(self.state_size, self.comparison_size)
self.add_module('state_transform_' + str(i), state_transform)
self.state_transforms.append(state_transform)
self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])
self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
self.add_module('beta_dense_' + str(i), self.beta_denses[-1])
def forward(self, states, signals):
s = self.number_of_subsamples
c_states = [states]
new_states = states
for i in range(1, self.number_of_subsamples):
embed = self.embeddings[i](signals[:, i::s])
# reduce signal dimension
embed = torch.flatten(embed, -2)
comp_states = self.state_transforms[i](new_states)
alpha = torch.tanh(self.alpha_dense(embed))
beta = torch.tanh(self.beta_dense(embed))
# new state obtained by modulating previous state
new_states = torch.tanh((1 + alpha) * comp_states + beta)
c_states.append(new_states)
return c_states

View file

@ -0,0 +1,36 @@
import torch
def find(a, v):
try:
idx = a.index(v)
except ValueError:
idx = -1
return idx
def interleave_tensors(tensors, dim=-2):
""" interleave list of tensors along sequence dimension """
x = torch.cat([x.unsqueeze(dim) for x in tensors], dim=dim)
x = torch.flatten(x, dim - 1, dim)
return x
def _interleave(x, pcm_levels=256):
repeats = pcm_levels // (2*x.size(-1))
x = x.unsqueeze(-1)
p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)
return p
def get_pdf_from_tree(x):
pcm_levels = x.size(-1)
p = _interleave(x[..., 1:2], pcm_levels)
n = 4
while n <= pcm_levels:
p = p * _interleave(x[..., n//2:n], pcm_levels)
n *= 2
return p
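# Sketch of the idea (using pcm_levels=8 for brevity): x[..., 1:] holds the
# conditional branch probabilities of a binary tree over the pcm levels and
# get_pdf_from_tree multiplies them out into a pdf over all leaves:
#
#   x = torch.sigmoid(torch.randn(1, 8))
#   p = get_pdf_from_tree(x)   # shape (1, 8), sums to 1 along the last dim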

View file

@ -0,0 +1,6 @@
def clip_to_int16(x):
int_min = -2**15
int_max = 2**15 - 1
x_clipped = max(int_min, min(x, int_max))
return x_clipped

View file

@ -0,0 +1,15 @@
import torch
def sample_excitation(probs, pitch_corr):
norm = lambda x : x / (x.sum() + 1e-18)
# lowering the temperature
probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))
# cut-off tails
probs = norm(torch.maximum(probs - 0.002 , torch.FloatTensor([0])))
# sample
exc = torch.multinomial(probs.squeeze(), 1)
return exc
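# Example (a sketch): a high pitch correlation lowers the sampling temperature
# and the 0.002 threshold cuts off unlikely levels before sampling:
#
#   probs = torch.softmax(torch.randn(256), dim=0)
#   exc = sample_excitation(probs, pitch_corr=0.8)   # one sampled u-law level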

View file

@ -0,0 +1,2 @@
from .gru_sparsifier import GRUSparsifier
from .common import sparsify_matrix, calculate_gru_flops_per_step

View file

@ -0,0 +1,92 @@
import torch
def sparsify_matrix(matrix : torch.Tensor, density : float, block_size : list, keep_diagonal : bool=False, return_mask : bool=False):
""" sparsifies matrix with specified block size
Parameters:
-----------
matrix : torch.tensor
matrix to sparsify
density : float
target density
block_size : [int, int]
block size dimensions
keep_diagonal : bool
If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
return_mask : bool
If true, the binary mask is returned together with the sparsified matrix. Defaults to False
"""
m, n = matrix.shape
m1, n1 = block_size
if m % m1 or n % n1:
raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
# extract diagonal if keep_diagonal = True
if keep_diagonal:
if m != n:
raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
to_spare = torch.diag(torch.diag(matrix))
matrix = matrix - to_spare
else:
to_spare = torch.zeros_like(matrix)
# calculate energy in sub-blocks
x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
x = x ** 2
block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
number_of_blocks = (m * n) // (m1 * n1)
number_of_survivors = round(number_of_blocks * density)
# masking threshold
if number_of_survivors == 0:
threshold = 0
else:
threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
# create mask
mask = torch.ones_like(block_energies)
mask[block_energies < threshold] = 0
mask = torch.repeat_interleave(mask, m1, dim=0)
mask = torch.repeat_interleave(mask, n1, dim=1)
# perform masking
masked_matrix = mask * matrix + to_spare
if return_mask:
return masked_matrix, mask
else:
return masked_matrix
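# Example (a sketch): keep only the 25% highest-energy 4x4 blocks of a 16x16 matrix
#
#   w = torch.randn(16, 16)
#   w_sparse, mask = sparsify_matrix(w, density=0.25, block_size=[4, 4], return_mask=True)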
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
input_size = gru.input_size
hidden_size = gru.hidden_size
flops = 0
input_density = (
sparsification_dict.get('W_ir', [1])[0]
+ sparsification_dict.get('W_in', [1])[0]
+ sparsification_dict.get('W_iz', [1])[0]
) / 3
recurrent_density = (
sparsification_dict.get('W_hr', [1])[0]
+ sparsification_dict.get('W_hn', [1])[0]
+ sparsification_dict.get('W_hz', [1])[0]
) / 3
# input matrix vector multiplications
if not drop_input:
flops += 2 * 3 * input_size * hidden_size * input_density
# recurrent matrix vector multiplications
flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
# biases
flops += 6 * hidden_size
# activations estimated by 10 flops per activation
flops += 30 * hidden_size
return flops
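# Example (a sketch): flops estimate for a dense GRU (empty sparsification dict)
#
#   gru = torch.nn.GRU(128, 384)
#   flops_per_step = calculate_gru_flops_per_step(gru)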

View file

@ -0,0 +1,158 @@
import torch
from .common import sparsify_matrix
class GRUSparsifier:
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.GRUs
Parameters:
-----------
task_list : list
task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
should be kept.
start : int
training step after which sparsification will be started.
stop : int
training step after which sparsification will be completed.
interval : int
sparsification interval for steps between start and stop. After stop sparsification will be
carried out after every call to GRUSparsifier.step()
exponent : float
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
with density (alpha + target_density * (1 - alpha)), where
alpha = ((stop - i) / (stop - start)) ** exponent
Example:
--------
>>> import torch
>>> gru = torch.nn.GRU(10, 20)
>>> sparsify_dict = {
... 'W_ir' : (0.5, [2, 2], False),
... 'W_iz' : (0.6, [2, 2], False),
... 'W_in' : (0.7, [2, 2], False),
... 'W_hr' : (0.1, [4, 4], True),
... 'W_hz' : (0.2, [4, 4], True),
... 'W_hn' : (0.3, [4, 4], True),
... }
>>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
>>> for i in range(100):
... sparsifier.step()
"""
# just copying parameters...
self.start = start
self.stop = stop
self.interval = interval
self.exponent = exponent
self.task_list = task_list
# ... and setting counter to 0
self.step_counter = 0
self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
def step(self, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
verbose : bool
if true, densities are printed out
Returns:
--------
None
"""
# compute current interpolation factor
self.step_counter += 1
if self.step_counter < self.start:
return
elif self.step_counter < self.stop:
# only sparsify every self.interval-th step
if self.step_counter % self.interval:
return
alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
else:
alpha = 0
with torch.no_grad():
for gru, params in self.task_list:
hidden_size = gru.hidden_size
# input weights
for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
if key in params:
density = alpha + (1 - alpha) * params[key][0]
if verbose:
print(f"[{self.step_counter}]: {key} density: {density}")
gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
density, # density
params[key][1], # block_size
params[key][2], # keep_diagonal (might want to set this to False)
return_mask=True
)
if self.last_masks[key] is not None:
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
print(f"sparsification mask {key} changed for gru {gru}")
self.last_masks[key] = new_mask
# recurrent weights
for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
if key in params:
density = alpha + (1 - alpha) * params[key][0]
if verbose:
print(f"[{self.step_counter}]: {key} density: {density}")
gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
density,
params[key][1], # block_size
params[key][2], # keep_diagonal (might want to set this to False)
return_mask=True
)
if self.last_masks[key] is not None:
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
print(f"sparsification mask {key} changed for gru {gru}")
self.last_masks[key] = new_mask
if __name__ == "__main__":
print("Testing sparsifier")
gru = torch.nn.GRU(10, 20)
sparsify_dict = {
'W_ir' : (0.5, [2, 2], False),
'W_iz' : (0.6, [2, 2], False),
'W_in' : (0.7, [2, 2], False),
'W_hr' : (0.1, [4, 4], True),
'W_hz' : (0.2, [4, 4], True),
'W_hn' : (0.3, [4, 4], True),
}
sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
for i in range(100):
sparsifier.step(verbose=True)

View file

@ -0,0 +1,128 @@
import copy
setup_dict = dict()
dataset_template_v2 = {
'version' : 2,
'feature_file' : 'features.f32',
'signal_file' : 'data.s16',
'frame_length' : 160,
'feature_frame_length' : 36,
'signal_frame_length' : 2,
'feature_dtype' : 'float32',
'signal_dtype' : 'int16',
'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
'signal_frame_layout' : {'last_signal' : 0, 'signal': 1} # signal, last_signal, error, prediction
}
dataset_template_v1 = {
'version' : 1,
'feature_file' : 'features.f32',
'signal_file' : 'data.u8',
'frame_length' : 160,
'feature_frame_length' : 55,
'signal_frame_length' : 4,
'feature_dtype' : 'float32',
'signal_dtype' : 'uint8',
'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
'signal_frame_layout' : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3} # signal, last_signal, error, prediction
}
# lpcnet
lpcnet_config = {
'frame_size' : 160,
'gru_a_units' : 384,
'gru_b_units' : 64,
'feature_conditioning_dim' : 128,
'feature_conv_kernel_size' : 3,
'period_levels' : 257,
'period_embedding_dim' : 64,
'signal_embedding_dim' : 128,
'signal_levels' : 256,
'feature_dimension' : 19,
'output_levels' : 256,
'lpc_gamma' : 0.9,
'features' : ['cepstrum', 'periods', 'pitch_corr'],
'signals' : ['last_signal', 'prediction', 'last_error'],
'input_layout' : { 'signals' : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2},
'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} },
'target' : 'error',
'feature_history' : 2,
'feature_lookahead' : 2,
'sparsification' : {
'gru_a' : {
'start' : 10000,
'stop' : 30000,
'interval' : 100,
'exponent' : 3,
'params' : {
'W_hr' : (0.05, [4, 8], True),
'W_hz' : (0.05, [4, 8], True),
'W_hn' : (0.2, [4, 8], True)
},
},
'gru_b' : {
'start' : 10000,
'stop' : 30000,
'interval' : 100,
'exponent' : 3,
'params' : {
'W_ir' : (0.5, [4, 8], False),
'W_iz' : (0.5, [4, 8], False),
'W_in' : (0.5, [4, 8], False)
},
}
},
'add_reference_phase' : False,
'reference_phase_dim' : 0
}
# multi rate
subconditioning = {
'subconditioning_a' : {
'number_of_subsamples' : 2,
'method' : 'modulative',
'signals' : ['last_signal', 'prediction', 'last_error'],
'pcm_embedding_size' : 64,
'kwargs' : dict()
},
'subconditioning_b' : {
'number_of_subsamples' : 2,
'method' : 'modulative',
'signals' : ['last_signal', 'prediction', 'last_error'],
'pcm_embedding_size' : 64,
'kwargs' : dict()
}
}
multi_rate_lpcnet_config = lpcnet_config.copy()
multi_rate_lpcnet_config['subconditioning'] = subconditioning
training_default = {
'batch_size' : 256,
'epochs' : 20,
'lr' : 1e-3,
'lr_decay_factor' : 2.5e-5,
'adam_betas' : [0.9, 0.99],
'frames_per_sample' : 15
}
lpcnet_setup = {
'dataset' : '/local/datasets/lpcnet_training',
'lpcnet' : {'config' : lpcnet_config, 'model': 'lpcnet'},
'training' : training_default
}
multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'
setup_dict = {
'lpcnet' : lpcnet_setup,
'multi_rate' : multi_rate_lpcnet_setup
}

View file

@ -0,0 +1,29 @@
import math as m
import torch
def ulaw2lin(u):
scale_1 = 32768.0 / 255.0
u = u - 128
s = torch.sign(u)
u = torch.abs(u)
return s * scale_1 * (torch.exp(u / 128. * m.log(256)) - 1)
def lin2ulawq(x):
scale = 255.0 / 32768.0
s = torch.sign(x)
x = torch.abs(x)
u = s * (128 * torch.log(1 + scale * x) / m.log(256))
u = torch.clip(128 + torch.round(u), 0, 255)
return u
def lin2ulaw(x):
scale = 255.0 / 32768.0
s = torch.sign(x)
x = torch.abs(x)
u = s * (128 * torch.log(1 + scale * x) / m.log(256))
u = torch.clip(128 + u, 0, 255)
return u
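# Quick round-trip sanity check (a sketch):
#
#   x = torch.tensor([-20000.0, -1.0, 0.0, 1.0, 20000.0])
#   x_hat = ulaw2lin(lin2ulawq(x))   # close to x, up to mu-law quantization error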

View file

@ -0,0 +1,14 @@
import wave
def wavwrite16(filename, x, fs):
""" writes x as int16 to file with name filename
If x.dtype is int16 x is written as is. Otherwise,
it is scaled by 2**15 - 1 and converted to int16.
"""
if x.dtype != 'int16':
x = ((2**15 - 1) * x).astype('int16')
with wave.open(filename, 'wb') as f:
f.setparams((1, 2, fs, len(x), 'NONE', ""))
f.writeframes(x.tobytes())
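# Example (a sketch):
#
#   import numpy as np
#   wavwrite16('silence.wav', np.zeros(16000, dtype='int16'), 16000)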