added LPCNet torch implementation

Signed-off-by: Jan Buethe <jbuethe@amazon.de>

parent 90a171c1c2
commit 35ee397e06

38 changed files with 3200 additions and 0 deletions
27  dnn/torch/lpcnet/README.md  Normal file
@@ -0,0 +1,27 @@
# LPCNet

Incomplete PyTorch implementation of LPCNet

## Data preparation

For data preparation, use dump_data from github.com/xiph/LPCNet. To turn this into
a training dataset, copy the data and feature files to a folder and run

python add_dataset_config.py my_dataset_folder

which writes an info.yml dataset description to the folder.

## Training

To train a model, create and adjust a setup file, e.g. with

python make_default_setup.py my_setup.yml --path2dataset my_dataset_folder

Then simply run

python train_lpcnet.py my_setup.yml my_output

## Inference

Create a feature file with dump_data from github.com/xiph/LPCNet. Then run, e.g.,

python test_lpcnet.py features.f32 my_output/checkpoints/checkpoint_ep_10.pth out.wav

Inference runs on CPU and usually takes between 3 and 20 seconds per generated second of audio,
depending on the CPU.
48  dnn/torch/lpcnet/add_dataset_config.py  Normal file
@@ -0,0 +1,48 @@
import argparse
import os

import yaml

from utils.templates import dataset_template_v1, dataset_template_v2


parser = argparse.ArgumentParser("add_dataset_config.py")

parser.add_argument('path', type=str, help='path to folder containing feature and data file')
parser.add_argument('--version', type=int, help="dataset version, 1 for classic LPCNet with 55 feature slots, 2 for new format with 36 feature slots", default=2)
parser.add_argument('--description', type=str, help='brief dataset description', default="I will add a description later")

args = parser.parse_args()

if args.version == 1:
    template = dataset_template_v1
    data_extension = '.u8'
elif args.version == 2:
    template = dataset_template_v2
    data_extension = '.s16'
else:
    raise ValueError(f"unknown dataset version {args.version}")

# get folder content
content = os.listdir(args.path)

features = [c for c in content if c.endswith('.f32')]

if len(features) != 1:
    print("could not determine feature file")
else:
    template['feature_file'] = features[0]

data = [c for c in content if c.endswith(data_extension)]
if len(data) != 1:
    print("could not determine data file")
else:
    template['signal_file'] = data[0]

template['description'] = args.description

with open(os.path.join(args.path, 'info.yml'), 'w') as f:
    yaml.dump(template, f)
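The keys that LPCNetDataset (below) reads from info.yml suggest roughly the following layout; this is an illustrative sketch only, and the concrete fields and values come from dataset_template_v1/v2 in utils.templates:

```python
# hypothetical info.yml content, rendered as a Python dict (values assumed)
example_info = {
    'version': 2,
    'description': 'I will add a description later',
    'feature_file': 'features.f32',
    'signal_file': 'data.s16',
    'feature_dtype': 'float32',                      # assumed
    'signal_dtype': 'int16',                         # assumed
    'feature_frame_length': 36,                      # feature slots per frame (v2)
    'signal_frame_length': 2,                        # assumed
    'frame_length': 160,                             # samples per feature frame
    'feature_frame_layout': {'cepstrum': [0, 18],    # assumed split of the 36 slots
                             'periods': [18, 19],
                             'pitch_corr': [19, 20],
                             'lpc': [20, 36]},
    'signal_frame_layout': {'last_signal': 0, 'signal': 1},  # assumed column order
}
```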
1  dnn/torch/lpcnet/data/__init__.py  Normal file
@@ -0,0 +1 @@
from .lpcnet_dataset import LPCNetDataset
198  dnn/torch/lpcnet/data/lpcnet_dataset.py  Normal file
@@ -0,0 +1,198 @@
""" Dataset for LPCNet training """
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
scale = 255.0/32768.0
|
||||
scale_1 = 32768.0/255.0
|
||||
def ulaw2lin(u):
|
||||
u = u - 128
|
||||
s = np.sign(u)
|
||||
u = np.abs(u)
|
||||
return s*scale_1*(np.exp(u/128.*np.log(256))-1)
|
||||
|
||||
|
||||
def lin2ulaw(x):
|
||||
s = np.sign(x)
|
||||
x = np.abs(x)
|
||||
u = (s*(128*np.log(1+scale*x)/np.log(256)))
|
||||
u = np.clip(128 + np.round(u), 0, 255)
|
||||
return u
|
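
# The pair above implements 8-bit mu-law companding around an offset of 128:
# lin2ulaw(0) == 128, and ulaw2lin inverts lin2ulaw up to the quantization
# introduced by np.round, e.g. (illustrative values):
#
#   lin2ulaw(np.array([1000.0]))              # -> array([178.])
#   ulaw2lin(lin2ulaw(np.array([1000.0])))    # -> approximately 992, not bit-exact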


def run_lpc(signal, lpcs, frame_length=160):
    num_frames, lpc_order = lpcs.shape

    prediction = np.concatenate(
        [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
    )
    error = signal[lpc_order :] - prediction

    return prediction, error
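# Convention, read off from the convolution above: lpcs[k] holds coefficients
# a_1 ... a_p for frame k, and
#
#   prediction[n] = -(a_1 * signal[n-1] + ... + a_p * signal[n-p]),
#
# so `signal` must carry lpc_order history samples in front of the first
# frame, i.e. len(signal) == num_frames * frame_length + lpc_order.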

class LPCNetDataset(Dataset):
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 input_signals=['last_signal', 'prediction', 'last_error'],
                 target='error',
                 frames_per_sample=15,
                 feature_history=2,
                 feature_lookahead=2,
                 lpc_gamma=1):

        super(LPCNetDataset, self).__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        self.frame_offset = 1 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in the dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.input_signals = input_signals
        self.target = target

        # load signals
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length
    def __getitem__(self, index):
        return self.getitem(index)

    def getitem_v2(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # lpc coefficients with one frame lookahead
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            # clean_signal is currently unused
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(lin2ulaw(sample[self.target]))
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
    def getitem_v1(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        # (signal_index must not shadow the sample index above)
        for signal_name, signal_index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, signal_index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        return self.dataset_length
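
Each sample returned by __getitem__ is a dict of tensors keyed 'features', 'periods', 'signals' and 'target'. A minimal usage sketch (the folder name is a placeholder, and the printed shapes depend on the dataset config):

```python
from torch.utils.data import DataLoader

from data import LPCNetDataset

dataset = LPCNetDataset('my_dataset_folder')   # placeholder path
loader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True)

batch = next(iter(loader))
# features: float, (batch, frames_per_sample + history + lookahead, feature_dim);
# periods, signals, target: long mu-law indices in [0, 255]
print({key: tensor.shape for key, tensor in batch.items()})
```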
112  dnn/torch/lpcnet/engine/lpcnet_engine.py  Normal file
@@ -0,0 +1,112 @@
import torch
from tqdm import tqdm
import sys

def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    # gru states
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # zero out initial gru states
            gru_a_state.zero_()
            gru_b_state.zero_()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

            # calculate loss
            loss = criterion(output.permute(0, 2, 1), target)

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # call sparsifier
            model.sparsify()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss

def evaluate(model, criterion, dataloader, device, log_interval=10):

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    # gru states
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # zero out initial gru states
                gru_a_state.zero_()
                gru_b_state.zero_()

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

                # calculate loss
                loss = criterion(output.permute(0, 2, 1), target)

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss
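
Both loops leave the choice of criterion open; since the models return log-probabilities over the mu-law levels, torch.nn.NLLLoss over (batch, levels, time) is the natural fit for the criterion(output.permute(0, 2, 1), target) call above. A minimal wiring sketch (config and dataloader are placeholders, and the optimizer and schedule choices are assumptions, not the recipe from the setup files):

```python
import torch

from models import LPCNet
from engine.lpcnet_engine import train_one_epoch

model = LPCNet(config)                               # config: placeholder setup dict
criterion = torch.nn.NLLLoss()                       # model outputs log-probabilities
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# scheduler.step() runs once per batch, so keep the sketch schedule constant
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

loss = train_one_epoch(model, criterion, optimizer, dataloader, 'cpu', scheduler)
```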
27  dnn/torch/lpcnet/make_default_setup.py  Normal file
@@ -0,0 +1,27 @@
import argparse

import yaml

from utils.templates import setup_dict

parser = argparse.ArgumentParser()

parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lpcnet', 'multi_rate'], help='LPCNet model name', default='lpcnet')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)

args = parser.parse_args()

setup = setup_dict[args.model]

# update dataset if given
if args.path2dataset is not None:
    setup['dataset'] = args.path2dataset

name = args.name
if not name.endswith('.yml'):
    name += '.yml'

if __name__ == '__main__':
    with open(name, 'w') as f:
        f.write(yaml.dump(setup))
49  dnn/torch/lpcnet/make_test_config.py  Normal file
@@ -0,0 +1,49 @@
import argparse
import os
import sys

parser = argparse.ArgumentParser()
parser.add_argument("config_name", type=str, help="name of config file (.yml will be appended)")
parser.add_argument("test_name", type=str, help="name for test result display")
parser.add_argument("checkpoint", type=str, help="checkpoint to test")
parser.add_argument("--lpcnet-demo", type=str, help="path to lpcnet_demo binary, default: /local/code/LPCNet/lpcnet_demo", default="/local/code/LPCNet/lpcnet_demo")
parser.add_argument("--lpcnext-path", type=str, help="path to lpcnext folder, default: dirname(__file__)", default=os.path.dirname(__file__))
parser.add_argument("--python-exe", type=str, help='python executable path, default: sys.executable', default=sys.executable)
parser.add_argument("--pad", type=str, help="left pad of output in seconds, default: 0.015", default="0.015")
parser.add_argument("--trim", type=str, help="left trim of output in seconds, default: 0", default="0")


template = '''
test: "{NAME}"
processing:
- "sox {{INPUT}} {{INPUT}}.raw"
- "{LPCNET_DEMO} -features {{INPUT}}.raw {{INPUT}}.features.f32"
- "{PYTHON} {WORKING}/test_lpcnet.py {{INPUT}}.features.f32 {CHECKPOINT} {{OUTPUT}}.ua.wav"
- "sox {{OUTPUT}}.ua.wav {{OUTPUT}}.uap.wav pad {PAD}"
- "sox {{OUTPUT}}.uap.wav {{OUTPUT}} trim {TRIM}"
- "rm {{INPUT}}.raw {{OUTPUT}}.uap.wav {{OUTPUT}}.ua.wav {{INPUT}}.features.f32"
'''

if __name__ == "__main__":
    args = parser.parse_args()

    file_content = template.format(
        NAME=args.test_name,
        LPCNET_DEMO=os.path.abspath(args.lpcnet_demo),
        PYTHON=os.path.abspath(args.python_exe),
        PAD=args.pad,
        TRIM=args.trim,
        WORKING=os.path.abspath(args.lpcnext_path),
        CHECKPOINT=os.path.abspath(args.checkpoint)
    )

    print(file_content)

    filename = args.config_name
    if not filename.endswith(".yml"):
        filename += ".yml"

    with open(filename, "w") as f:
        f.write(file_content)
8  dnn/torch/lpcnet/models/__init__.py  Normal file
@@ -0,0 +1,8 @@
from .lpcnet import LPCNet
from .multi_rate_lpcnet import MultiRateLPCNet


model_dict = {
    'lpcnet' : LPCNet,
    'multi_rate' : MultiRateLPCNet
}
274  dnn/torch/lpcnet/models/lpcnet.py  Normal file
@@ -0,0 +1,274 @@
import torch
from torch import nn
import numpy as np

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.layers import DualFC
from utils.misc import get_pdf_from_tree


class LPCNet(nn.Module):
    def __init__(self, config):
        super(LPCNet, self).__init__()

        # general parameters
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2 = nn.Linear(*(2 * [self.feature_conditioning_dim]))

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']
        self.hsampling = config.get('hsampling', False)

        self.gru_a_input_dim = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
        self.dual_fc = DualFC(self.gru_b_units, self.output_levels)

        # sparsification
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                     gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                     gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # inference parameters
        self.lpc_gamma = config.get('lpc_gamma', 1)

    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        # hierarchical_sampling is accepted for interface compatibility with
        # MultiRateLPCNet.get_gflops (print_lpcnet_complexity.py passes it);
        # here the hsampling flag from the config decides whether the dual_fc
        # cost is divided by 8
        gflops = 0

        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # gru b
        gru_b_rate = fs
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # dual fcs
        fc = self.dual_fc
        rate = fs
        input_size = fc.dense1.in_features
        output_size = fc.dense1.out_features
        dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
        if self.hsampling:
            dual_fc_complexity /= 8
        if verbose:
            print(f"dual_fc: {dual_fc_complexity} GFLOPS")
        gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops
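
    # Worked example for the frame-rate-network term above, with assumed sizes
    # (not taken from a repo config): conditioning_dim = 128,
    # feature_channels = 102, fs = 16000, frame_size = 160 (frame rate 100 Hz):
    #
    #   2 * (5 * 128 + 3 * 102) * 128 * 100 = 2 * 946 * 128 * 100 ≈ 2.4e7 FLOPS,
    #
    # i.e. about 0.024 GFLOPS; the two GRUs, which run at sample rate fs,
    # dominate the total.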

    def frame_rate_network(self, features, periods):

        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def sample_rate_network(self, signals, c, gru_states):
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
        c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1)

        y = torch.concat((embedded_signals, c_upsampled), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c_upsampled), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
        else:
            log_probs = torch.log_softmax(y, dim=-1)

        return log_probs, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            probs = get_pdf_from_tree(y)
        else:
            probs = torch.softmax(y, dim=-1)

        return probs, (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):

        c = self.frame_rate_network(features, periods)
        log_probs, _ = self.sample_rate_network(signals, c, gru_states)

        return log_probs

    def generate(self, features, periods, lpcs):

        with torch.no_grad():
            device = next(self.parameters()).device

            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            num_input_signals = len(self.input_layout['signals'])
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers
            pcm = torch.zeros((num_frames * self.frame_size + lpc_order))
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))
            gru_states = [gru_a_state, gru_b_state]

            input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # lpc weighting
            weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
            lpcs = lpcs * weights

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(self.frame_size):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # prepare input
                    pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
                    if 'prediction' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)

                    # run single step of sample rate network
                    probs, gru_states = self.decoder(
                        input_signals,
                        current_c,
                        gru_states
                    )

                    # sample from output
                    exc_ulaw = sample_excitation(probs, pitch_corr)

                    # signal generation
                    exc = ulaw2lin(exc_ulaw)
                    sig = exc + pred
                    pcm[pcm_position] = sig
                    mem = 0.85 * mem + float(sig)
                    output[output_position] = clip_to_int16(round(mem))

                    # buffer update
                    if 'last_signal' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)

                    if 'last_error' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc)

        return output
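
The running mem variable implements the usual LPCNet deemphasis filter 1 / (1 - 0.85 z^-1) on the generated signal. A generation sketch follows; the checkpoint path, the 'state_dict' key and the tensor shapes are assumptions for illustration, and test_lpcnet.py is the reference for preparing real inputs:

```python
import torch

from models import LPCNet

checkpoint = torch.load('my_output/checkpoints/checkpoint_ep_10.pth', map_location='cpu')
model = LPCNet(checkpoint['setup']['lpcnet']['config'])
model.load_state_dict(checkpoint['state_dict'])     # key name assumed

features = torch.zeros((100, 20))                   # (frames, feature_dim), layout assumed
periods = torch.zeros((100, 1), dtype=torch.long) + 100
lpcs = torch.zeros((100, 16))                       # (frames, lpc_order), order assumed

pcm = model.generate(features, periods, lpcs)       # int16, frame_size samples per frame
```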
408  dnn/torch/lpcnet/models/multi_rate_lpcnet.py  Normal file
@@ -0,0 +1,408 @@
import torch
from torch import nn
from utils.layers.subconditioner import get_subconditioner
from utils.layers import DualFC

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step

from utils.misc import interleave_tensors


# MultiRateLPCNet
class MultiRateLPCNet(nn.Module):
    def __init__(self, config):
        super(MultiRateLPCNet, self).__init__()

        # general parameters
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']
        self.signals = config['signals']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2 = nn.Linear(*(2 * [self.feature_conditioning_dim]))

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']

        # subconditioning B
        sub_config = config['subconditioning']['subconditioning_b']
        self.substeps_b = sub_config['number_of_subsamples']
        self.subcondition_signals_b = sub_config['signals']
        self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        # normalize kwargs and pass the normalized dict, so a null kwargs
        # entry in the setup does not crash the constructor
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_b_units
        self.subconditioner_b = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, len(sub_config['signals']),
            **kwargs)

        # subconditioning A
        sub_config = config['subconditioning']['subconditioning_a']
        self.substeps_a = sub_config['number_of_subsamples']
        self.subcondition_signals_a = sub_config['signals']
        self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_a_units
        self.subconditioner_a = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
            **kwargs)

        # wrap up subconditioning: group_size_gru_a holds the number
        # of timesteps that are grouped as sample input for GRU A,
        # and group_size_subcondition_a holds the number of samples that
        # are grouped as input to pre-GRU B subconditioning
        self.group_size_gru_a = self.substeps_a * self.substeps_b
        self.group_size_subcondition_a = self.substeps_b
        self.gru_a_rate_divider = self.group_size_gru_a
        self.gru_b_rate_divider = self.substeps_b

        # gru sizes
        self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
        self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)

        # sparsification
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                     gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                     gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # dual FCs
        self.dual_fc = []
        for i in range(self.substeps_b):
            dim = self.subconditioner_b.get_output_dim(i)
            self.dual_fc.append(DualFC(dim, self.output_levels))
            self.add_module(f"dual_fc_{i}", self.dual_fc[-1])

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        gflops = 0

        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs / self.group_size_gru_a
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # subconditioning a
        subcond_a_rate = fs / self.substeps_b
        subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
        if verbose:
            print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
        gflops += subconditioning_a_complexity

        # gru b
        gru_b_rate = fs / self.substeps_b
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # subconditioning b
        subcond_b_rate = fs
        subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
        if verbose:
            print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
        gflops += subconditioning_b_complexity

        # dual fcs
        for i, fc in enumerate(self.dual_fc):
            rate = fs / len(self.dual_fc)
            input_size = fc.dense1.in_features
            output_size = fc.dense1.out_features
            dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
            if hierarchical_sampling:
                dual_fc_complexity /= 8
            if verbose:
                print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
            gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def frame_rate_network(self, features, periods):

        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def prepare_signals(self, signals, group_size, signal_idx):
        """ extracts, delays and groups signals """

        batch_size, sequence_length, num_signals = signals.shape

        # extract signals according to position
        signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
                            dim=-1)

        # roll back pcm to account for grouping
        signals = torch.roll(signals, group_size - 1, -2)

        # reshape
        signals = torch.reshape(signals,
                                (batch_size, sequence_length // group_size, group_size * len(signal_idx)))

        return signals
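
    # Shape sketch with assumed numbers: for signals of shape (B, T, S),
    # group_size = 4 and signal_idx selecting 2 columns, the roll delays the
    # selected columns by group_size - 1 samples and the reshape yields
    # (B, T // 4, 4 * 2), i.e. each grouped step sees the last 4 timesteps
    # of both selected signals.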

    def sample_rate_network(self, signals, c, gru_states):

        signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
        embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
        # features at GRU A rate
        c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
        # features at GRU B rate
        c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)

        y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # first round of upsampling and subconditioning
        c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
        y = self.subconditioner_a(y, c_signals_a)
        y = interleave_tensors(y)

        y = torch.concat((y, c_upsampled_b), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
        y = self.subconditioner_b(y, c_signals_b)

        y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
        y = interleave_tensors(y)

        return y, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        # note: this legacy single-rate decoder is not called by generate(),
        # which inlines the multi-rate loop below
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):

        c = self.frame_rate_network(features, periods)
        y, _ = self.sample_rate_network(signals, c, gru_states)
        log_probs = torch.log_softmax(y, dim=-1)

        return log_probs

    def generate(self, features, periods, lpcs):

        with torch.no_grad():
            device = next(self.parameters()).device

            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            num_input_signals = len(self.signals)
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers
            last_signal = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            prediction = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            last_error = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))

            input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
            # conditioning signals for subconditioner a
            c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
            # conditioning signals for subconditioner b
            c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)

            # signal dict
            signal_dict = {
                'prediction' : prediction,
                'last_error' : last_error,
                'last_signal' : last_signal
            }

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(0, self.frame_size, self.group_size_gru_a):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # calculate newest prediction
                    prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1 : pcm_position + 1] * a)

                    # prepare input
                    for slot in range(self.group_size_gru_a):
                        k = slot - self.group_size_gru_a + 1
                        for idx, name in enumerate(self.signals):
                            input_signals[idx + slot * num_input_signals] = lin2ulawq(
                                signal_dict[name][pcm_position + k]
                            )

                    # run GRU A
                    embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
                    embed_signals = torch.flatten(embed_signals, 2)
                    y = torch.cat((embed_signals, current_c), dim=-1)
                    h_a, gru_a_state = self.gru_a(y, gru_a_state)

                    # loop over substeps_a
                    for step_a in range(self.substeps_a):
                        # prepare conditioning input
                        for slot in range(self.group_size_subcondition_a):
                            k = slot - self.group_size_subcondition_a + 1
                            for idx, name in enumerate(self.subcondition_signals_a):
                                c_signals_a[idx + slot * num_input_signals] = lin2ulawq(
                                    signal_dict[name][pcm_position + k]
                                )

                        # subconditioning
                        h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))

                        # run GRU B
                        y = torch.cat((h_a, current_c), dim=-1)
                        h_b, gru_b_state = self.gru_b(y, gru_b_state)

                        # loop over substeps b
                        for step_b in range(self.substeps_b):
                            # prepare subconditioning input
                            for idx, name in enumerate(self.subcondition_signals_b):
                                c_signals_b[idx] = lin2ulawq(
                                    signal_dict[name][pcm_position]
                                )

                            # subcondition
                            h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))

                            # run dual FC
                            probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)

                            # sample
                            new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))

                            # update signals
                            sig = new_exc + prediction[pcm_position]
                            last_error[pcm_position + 1] = new_exc
                            last_signal[pcm_position + 1] = sig

                            mem = 0.85 * mem + float(sig)
                            output[output_position] = clip_to_int16(round(mem))

                            # increase positions
                            pcm_position += 1
                            output_position += 1

                            # calculate next prediction
                            prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1 : pcm_position + 1] * a)

        return output
35  dnn/torch/lpcnet/print_lpcnet_complexity.py  Normal file
@@ -0,0 +1,35 @@
import argparse

import yaml

from models import model_dict


debug = False
if debug:
    args = type('dummy', (object,),
        {
            'setup' : 'setups/lpcnet_m/setup_1_4_concatenative.yml',
            'hierarchical_sampling' : False
        })()
else:
    parser = argparse.ArgumentParser()
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('--hierarchical-sampling', action="store_true", help='whether to assume hierarchical sampling (default: False)', default=False)

    args = parser.parse_args()

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# check model
if 'model' not in setup['lpcnet']:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = setup['lpcnet']['model']

# create model
model = model_dict[model_name](setup['lpcnet']['config'])

gflops = model.get_gflops(16000, verbose=True, hierarchical_sampling=args.hierarchical_sampling)
161  dnn/torch/lpcnet/scripts/collect_multi_run_results.py  Normal file
@@ -0,0 +1,161 @@
import argparse
import os
from uuid import UUID
from collections import OrderedDict
import pickle

import torch
import numpy as np

import utils


parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input folder containing multi-run output")
parser.add_argument("tag", type=str, help="tag for multi-run experiment")
parser.add_argument("csv", type=str, help="name for output csv")


def is_uuid(val):
    try:
        UUID(val)
        return True
    except ValueError:
        return False


def collect_results(folder):

    training_folder = os.path.join(folder, 'training')
    testing_folder = os.path.join(folder, 'testing')

    # validation loss
    checkpoint = torch.load(os.path.join(training_folder, 'checkpoints', 'checkpoint_finalize_epoch_1.pth'), map_location='cpu')
    validation_loss = checkpoint['validation_loss']

    # eval_warpq
    eval_warpq = utils.data.parse_warpq_scores(os.path.join(training_folder, 'out_finalize.txt'))[-1]

    # testing results
    testing_results = utils.data.collect_test_stats(os.path.join(testing_folder, 'final'))

    results = OrderedDict()
    results['eval_loss'] = validation_loss
    results['eval_warpq'] = eval_warpq
    results['pesq_mean'] = testing_results['pesq'][0]
    results['warpq_mean'] = testing_results['warpq'][0]
    results['pitch_error_mean'] = testing_results['pitch_error'][0]
    results['voicing_error_mean'] = testing_results['voicing_error'][0]

    return results


def print_csv(path, results, tag, ranks=None, header=True):

    metrics = next(iter(results.values())).keys()
    if ranks is not None:
        rank_keys = next(iter(ranks.values())).keys()
    else:
        rank_keys = []

    with open(path, 'w') as f:
        if header:
            f.write("uuid, tag")

            for metric in metrics:
                f.write(f", {metric}")

            for rank in rank_keys:
                f.write(f", {rank}")

            f.write("\n")

        for uuid, values in results.items():
            f.write(f"{uuid}, {tag}")

            for val in values.values():
                f.write(f", {val:10.8f}")

            for rank in rank_keys:
                f.write(f", {ranks[uuid][rank]:4d}")

            f.write("\n")


def get_ranks(results):
    """ ranks every uuid under every metric (1 is best) """

    metrics = list(next(iter(results.values())).keys())

    # metrics where higher is better; all others are errors or losses
    positive = {'pesq_mean', 'mix'}

    ranks = OrderedDict()
    for key in results.keys():
        ranks[key] = OrderedDict()

    for metric in metrics:
        sign = -1 if metric in positive else 1

        x = sorted([(key, value[metric]) for key, value in results.items()], key=lambda x: sign * x[1])
        x = [y[0] for y in x]

        for key in results.keys():
            ranks[key]['rank_' + metric] = x.index(key) + 1

    return ranks


def analyse_metrics(results):
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']

    x = []
    for metric in metrics:
        x.append([val[metric] for val in results.values()])

    x = np.array(x)

    print(x)


def add_mix_metric(results):
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']

    x = []
    for metric in metrics:
        x.append([val[metric] for val in results.values()])

    # flip signs so that higher is better for every column
    x = np.array(x).transpose() * np.array([-1, 1, -1, -1, -1])

    # z-score normalization per metric
    z = (x - np.mean(x, axis=0)) / np.std(x, axis=0)

    print(f"covariance matrix for normalized scores of {metrics}:")
    print(np.cov(z.transpose()))

    score = np.mean(z, axis=1)

    for i, key in enumerate(results.keys()):
        results[key]['mix'] = score[i].item()
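
# The mix score above is the mean of per-metric z-scores
#   z = (x - mean(x)) / std(x)
# taken after flipping signs so that larger is better for every metric; a run
# one standard deviation better than average on all five metrics gets mix = 1.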

if __name__ == "__main__":
    args = parser.parse_args()

    uuids = sorted([x for x in os.listdir(args.input) if os.path.isdir(os.path.join(args.input, x)) and is_uuid(x)])

    results = OrderedDict()

    for uuid in uuids:
        results[uuid] = collect_results(os.path.join(args.input, uuid))

    add_mix_metric(results)

    ranks = get_ranks(results)

    csv = args.csv if args.csv.endswith('.csv') else args.csv + '.csv'

    # use the normalized name so the .csv extension is guaranteed
    print_csv(csv, results, args.tag, ranks=ranks)

    with open(csv[:-4] + '.pickle', 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
52  dnn/torch/lpcnet/scripts/loop_run.sh  Normal file
@@ -0,0 +1,52 @@
#!/bin/bash


case $# in
    9) SETUP=$1; OUTDIR=$2; NAME=$3; DEVICE=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
    *) echo "loop_run.sh setup outdir name device rounds lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
esac


PYTHON="/home/ubuntu/opt/miniconda3/envs/torch/bin/python"
TESTFEATURES=${LPCNEXT}/testitems/features/all_0_orig_features.f32
WARPQREFERENCE=${LPCNEXT}/testitems/wav/all_0_orig.wav
METRICS="warpq,pesq,pitch_error,voicing_error"
LPCNETDEMO=${LPCNET}/lpcnet_demo

for ((round = 1; round <= $ROUNDS; round++))
do
    echo
    echo round $round

    UUID=$(uuidgen)
    TRAINOUT=${OUTDIR}/${UUID}/training
    TESTOUT=${OUTDIR}/${UUID}/testing
    CHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_last.pth
    FINALCHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_finalize_last.pth

    # run training
    echo "starting training..."
    $PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT --device $DEVICE --test-features $TESTFEATURES --warpq-reference $WARPQREFERENCE

    # run finalization
    echo "starting finalization..."
    $PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT \
        --device $DEVICE --test-features $TESTFEATURES \
        --warpq-reference $WARPQREFERENCE \
        --finalize --initial-checkpoint $CHECKPOINT

    # create test configs
    $PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig.yml "$NAME $UUID" $CHECKPOINT --lpcnet-demo $LPCNETDEMO
    $PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig_finalize.yml "$NAME $UUID finalized" $FINALCHECKPOINT --lpcnet-demo $LPCNETDEMO

    # run tests
    echo "starting test 1 (no finalization)..."
    $PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig.yml \
        $TESTITEMS ${TESTOUT}/prefinal --num-workers 8 \
        --num-testitems 400 --metrics $METRICS

    echo "starting test 2 (after finalization)..."
    $PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig_finalize.yml \
        $TESTITEMS ${TESTOUT}/final --num-workers 8 \
        --num-testitems 400 --metrics $METRICS
done
37  dnn/torch/lpcnet/scripts/make_animation.py  Normal file
@@ -0,0 +1,37 @@
""" script for creating animations from debug data
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
import sys
|
||||
sys.path.append('./')
|
||||
|
||||
from utils.endoscopy import make_animation, read_data
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('folder', type=str, help='endoscopy folder with debug output')
|
||||
parser.add_argument('output', type=str, help='output file (will be auto-extended with .mp4)')
|
||||
|
||||
parser.add_argument('--start-index', type=int, help='index of first sample to be considered', default=0)
|
||||
parser.add_argument('--stop-index', type=int, help='index of last sample to be considered', default=-1)
|
||||
parser.add_argument('--interval', type=int, help='interval between frames in ms', default=20)
|
||||
parser.add_argument('--half-window-length', type=int, help='half size of window for displaying signals', default=80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
filename = args.output if args.output.endswith('.mp4') else args.output + '.mp4'
|
||||
data = read_data(args.folder)
|
||||
|
||||
make_animation(
|
||||
data,
|
||||
filename,
|
||||
start_index=args.start_index,
|
||||
stop_index = args.stop_index,
|
||||
half_signal_window_length=args.half_window_length
|
||||
)
|
17
dnn/torch/lpcnet/scripts/modify_dataset_target.py
Normal file

@@ -0,0 +1,17 @@
import argparse

import numpy as np


parser = argparse.ArgumentParser(description="sets s_t to augmented_s_t")

parser.add_argument('datafile', type=str, help='data.s16 file path')

args = parser.parse_args()

data = np.memmap(args.datafile, dtype='int16', mode='readwrite')

# signal is in data[1::2]
# last augmented signal is in data[0::2]

data[1:-1:2] = data[2::2]
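# i.e. for the interleaved stream [aug_s_{-1}, s_0, aug_s_0, s_1, aug_s_1, ...]
# the assignment above copies aug_s_t into the s_t slot for all but the last frame.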
17
dnn/torch/lpcnet/scripts/multi_run.sh
Normal file

@@ -0,0 +1,17 @@
#!/bin/bash

case $# in
    9) SETUP=$1; OUTDIR=$2; NAME=$3; NUMDEVICES=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
    *) echo "multi_run.sh setup outdir name num_devices rounds_per_device lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
esac

LOOPRUN=${LPCNEXT}/loop_run.sh

mkdir -p $OUTDIR

for ((i = 0; i < $NUMDEVICES; i++))
do
    echo "launching job queue for device $i"
    nohup bash $LOOPRUN $SETUP $OUTDIR "$NAME" "cuda:$i" $ROUNDS $LPCNEXT $LPCNET $TESTSUITE $TESTITEMS > $OUTDIR/job_${i}_out.txt &
done
22
dnn/torch/lpcnet/scripts/run_inference_test.sh
Normal file

@@ -0,0 +1,22 @@
#!/bin/bash

case $# in
    3) FEATURES=$1; FOLDER=$2; PYTHON=$3;;
    *) echo "run_inference_test.sh <features file> <output folder> <python path>"; exit;;
esac

SCRIPTFOLDER=$(dirname "$0")

mkdir -p $FOLDER/inference_test

# run inference with every checkpoint found in the output folder
for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
do
    tmp=$(basename $fn)
    tmp=${tmp%.pth}
    epoch=${tmp#checkpoint_epoch_}
    echo "running inference with checkpoint $fn..."
    $PYTHON $SCRIPTFOLDER/../test_lpcnet.py $FEATURES $fn $FOLDER/inference_test/output_epoch_${epoch}.wav
done
25
dnn/torch/lpcnet/scripts/update_checkpoints.py
Normal file

@@ -0,0 +1,25 @@
""" script for updating checkpoints with new setup entries
|
||||
|
||||
Use this script to update older outputs with newly introduced
|
||||
parameters. (Saves us the trouble of backward compatibility)
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('checkpoint_file', type=str, help='checkpoint to be updated')
|
||||
parser.add_argument('--model', type=str, help='model update', default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_file, map_location='cpu')
|
||||
|
||||
# update model entry
|
||||
if type(args.model) != type(None):
|
||||
checkpoint['setup']['lpcnet']['model'] = args.model
|
||||
|
||||
torch.save(checkpoint, args.checkpoint_file)
|
22
dnn/torch/lpcnet/scripts/update_output_folder.sh
Normal file

@@ -0,0 +1,22 @@
#!/bin/bash

case $# in
    3) FOLDER=$1; MODEL=$2; PYTHON=$3;;
    *) echo "update_output_folder.sh folder model python"; exit;;
esac

SCRIPTFOLDER=$(dirname "$0")

# update setup
echo "updating $FOLDER/setup.yml..."
$PYTHON $SCRIPTFOLDER/update_setups.py $FOLDER/setup.yml --model $MODEL

# update checkpoints
for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
do
    echo "updating $fn..."
    $PYTHON $SCRIPTFOLDER/update_checkpoints.py $fn --model $MODEL
done
28
dnn/torch/lpcnet/scripts/update_setups.py
Normal file

@@ -0,0 +1,28 @@
""" script for updating setup files with new setup entries
|
||||
|
||||
Use this script to update older outputs with newly introduced
|
||||
parameters. (Saves us the trouble of backward compatibility)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup_file', type=str, help='setup to be updated')
|
||||
parser.add_argument('--model', type=str, help='model update', default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# load setup
|
||||
with open(args.setup_file, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
# update model entry
|
||||
if type(args.model) != type(None):
|
||||
setup['lpcnet']['model'] = args.model
|
||||
|
||||
# dump result
|
||||
with open(args.setup_file, 'w') as f:
|
||||
yaml.dump(setup, f)
|
60
dnn/torch/lpcnet/test_lpcnet.py
Normal file

@@ -0,0 +1,60 @@
import argparse

import torch
import numpy as np

from models import model_dict
from utils.data import load_features
from utils.wav import wavwrite16

debug = False
if debug:
    args = type('dummy', (object,),
    {
        'features'   : 'features.f32',
        'checkpoint' : 'checkpoint.pth',
        'output'     : 'out.wav',
        'version'    : 2
    })()
else:
    parser = argparse.ArgumentParser()

    parser.add_argument('features', type=str, help='feature file')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--version', type=int, help='feature version', default=2)

    args = parser.parse_args()


torch.set_num_threads(2)

version = args.version
feature_file = args.features
checkpoint_file = args.checkpoint

output_file = args.output
if not output_file.endswith('.wav'):
    output_file += '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# check model
if 'model' not in checkpoint['setup']['lpcnet']:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = checkpoint['setup']['lpcnet']['model']

model = model_dict[model_name](checkpoint['setup']['lpcnet']['config'])

model.load_state_dict(checkpoint['state_dict'])

# honor the --version flag (previously parsed but not forwarded)
data = load_features(feature_file, version=version)

output = model.generate(data['features'], data['periods'], data['lpcs'])

wavwrite16(output_file, output.numpy(), 16000)
243
dnn/torch/lpcnet/train_lpcnet.py
Normal file

@@ -0,0 +1,243 @@
import os
import argparse
import sys

try:
    import git
    has_git = True
except ImportError:
    has_git = False

import yaml

import torch
from torch.optim.lr_scheduler import LambdaLR

from data import LPCNetDataset
from models import model_dict
from engine.lpcnet_engine import train_one_epoch, evaluate
from utils.data import load_features
from utils.wav import wavwrite16


debug = False
if debug:
    args = type('dummy', (object,),
    {
        'setup'              : 'setup.yml',
        'output'             : 'testout',
        'device'             : None,
        'test_features'      : None,
        'finalize'           : False,
        'initial_checkpoint' : None,
        'no_redirect'        : False
    })()
else:
    parser = argparse.ArgumentParser("train_lpcnet.py")
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('output', type=str, help='output path')
    parser.add_argument('--device', type=str, help='compute device', default=None)
    parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
    parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
    parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
    parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of output')

    args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

if args.finalize:
    if args.initial_checkpoint is None:
        raise ValueError('finalization requires initial checkpoint')

    if 'sparsification' in setup['lpcnet']['config']:
        for sp_job in setup['lpcnet']['config']['sparsification'].values():
            sp_job['start'], sp_job['stop'] = 0, 0

    setup['training']['lr'] = 1.0e-5
    setup['training']['lr_decay_factor'] = 0.0
    setup['training']['epochs'] = 1

    checkpoint_prefix = 'checkpoint_finalize'
    output_prefix = 'output_finalize'
    setup_name = 'setup_finalize.yml'
    output_file = 'out_finalize.txt'
else:
    checkpoint_prefix = 'checkpoint'
    output_prefix = 'output'
    setup_name = 'setup.yml'
    output_file = 'out.txt'


# check model
if 'model' not in setup['lpcnet']:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = setup['lpcnet']['model']

# prepare output folder
if os.path.exists(args.output) and not debug and not args.finalize:
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit()
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)


# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

# prepare inference test if wanted
run_inference_test = False
if args.test_features is not None:
    test_features = load_features(args.test_features)
    inference_test_dir = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_test_dir, exist_ok=True)
    run_inference_test = True

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
lpcnet_config = setup['lpcnet']['config']
data = LPCNetDataset(   setup['dataset'],
                        features=lpcnet_config['features'],
                        input_signals=lpcnet_config['signals'],
                        target=lpcnet_config['target'],
                        frames_per_sample=setup['training']['frames_per_sample'],
                        feature_history=lpcnet_config['feature_history'],
                        feature_lookahead=lpcnet_config['feature_lookahead'],
                        lpc_gamma=lpcnet_config.get('lpc_gamma', 1))

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetDataset(setup['validation_dataset'],
                                    features=lpcnet_config['features'],
                                    input_signals=lpcnet_config['signals'],
                                    target=lpcnet_config['target'],
                                    frames_per_sample=setup['training']['frames_per_sample'],
                                    feature_history=lpcnet_config['feature_history'],
                                    feature_lookahead=lpcnet_config['feature_lookahead'],
                                    lpc_gamma=lpcnet_config.get('lpc_gamma', 1))

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](setup['lpcnet']['config'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))

# loss
criterion = torch.nn.NLLLoss()

# model checkpoint
checkpoint = {
    'setup'      : setup,
    'state_dict' : model.state_dict(),
    'loss'       : -1
}

if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss'] = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
            best_loss = validation_loss

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))

    # run inference test
    if run_inference_test:
        model.to("cpu")
        print("running inference test...")

        output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])

        testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')

        wavwrite16(testfilename, output.numpy(), 16000)

        model.to(device)

    print()
4
dnn/torch/lpcnet/utils/__init__.py
Normal file

@@ -0,0 +1,4 @@
from . import sparsification
from . import data
from . import pcm
from . import sample
112
dnn/torch/lpcnet/utils/data.py
Normal file

@@ -0,0 +1,112 @@
import os

import torch
import numpy as np

def load_features(feature_file, version=2):
    if version == 2:
        layout = {
            'cepstrum': [0, 18],
            'periods': [18, 19],
            'pitch_corr': [19, 20],
            'lpc': [20, 36]
            }
        frame_length = 36

    elif version == 1:
        layout = {
            'cepstrum': [0, 18],
            'periods': [36, 37],
            'pitch_corr': [37, 38],
            'lpc': [39, 55],
            }
        frame_length = 55
    else:
        raise ValueError(f'unknown feature version: {version}')


    raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    raw_features = raw_features.reshape((-1, frame_length))

    features = torch.cat(
        [
            raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
            raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
        ],
        dim=1
    )

    lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
    periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()

    return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
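# Usage sketch (v2 layout, i.e. the 36-slot feature format):
#
#   data = load_features('features.f32')
#   data['features']   # (num_frames, 19): 18 cepstral coefficients + pitch correlation
#   data['periods']    # (num_frames, 1) integer pitch periods
#   data['lpcs']       # (num_frames, 16) LPC coefficients per frame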

def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)

    assert len(signal) % 160 == 0
    num_frames = len(signal) // 160
    mem = np.zeros(1)
    for fr in range(num_frames):
        signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
        mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)

    new_data[:] = 0
    N = len(signal) - offset
    new_data[1 : 2*N + 1 : 2] = signal_preemph[offset:]
    new_data[2 : 2*N + 2 : 2] = signal_preemph[offset:]


def parse_warpq_scores(output_file):
    """ extracts warpq scores from output file """

    with open(output_file, "r") as f:
        lines = f.readlines()

    scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]

    return scores


def parse_stats_file(file):

    with open(file, "r") as f:
        lines = f.readlines()

    mean = float(lines[0].split(":")[-1])
    bt_mean = float(lines[1].split(":")[-1])
    top_mean = float(lines[2].split(":")[-1])

    return mean, bt_mean, top_mean

def collect_test_stats(test_folder):
    """ collects statistics for all discovered metrics from test folder """

    metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}

    results = dict()

    content = os.listdir(test_folder)

    stats_files = [file for file in content if file.startswith('stats_')]

    for file in stats_files:
        metric = file[len("stats_") : -len(".txt")]

        if metric not in metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))

        results[metric] = [mean, bt_mean, top_mean]

    return results
205
dnn/torch/lpcnet/utils/endoscopy.py
Normal file

@@ -0,0 +1,205 @@
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
|
||||
hidden_size = gru.hidden_size
|
||||
|
||||
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
|
||||
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
|
||||
|
||||
# reset gate
|
||||
start, stop = 0 * hidden_size, 1 * hidden_size
|
||||
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# update gate
|
||||
start, stop = 1 * hidden_size, 2 * hidden_size
|
||||
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# new gate
|
||||
start, stop = 2 * hidden_size, 3 * hidden_size
|
||||
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
|
||||
|
||||
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
|
||||
""" sets up output folder for endoscopy data """
|
||||
|
||||
global _folder
|
||||
_folder = folder
|
||||
|
||||
if not os.path.exists(folder):
|
||||
os.makedirs(folder)
|
||||
else:
|
||||
print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
|
||||
|
||||
def write_data(key, data, fs):
|
||||
""" appends data to previous data written under key """
|
||||
|
||||
global _state
|
||||
|
||||
# convert to numpy if torch.Tensor is given
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.detach().numpy()
|
||||
|
||||
if not key in _state:
|
||||
_state[key] = {
|
||||
'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
|
||||
'fs' : fs,
|
||||
'dim' : tuple(data.shape),
|
||||
'dtype' : str(data.dtype)
|
||||
}
|
||||
|
||||
with open(os.path.join(_folder, key + '.yml'), 'w') as f:
|
||||
f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
|
||||
else:
|
||||
if _state[key]['fs'] != fs:
|
||||
raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
|
||||
if _state[key]['dtype'] != str(data.dtype):
|
||||
raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
|
||||
if _state[key]['dim'] != tuple(data.shape):
|
||||
raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
|
||||
|
||||
_state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
|
||||
""" clean up """
|
||||
for key in _state.keys():
|
||||
_state[key]['fid'].close()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
|
||||
""" retrieves written data as numpy arrays """
|
||||
|
||||
|
||||
keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
|
||||
|
||||
return_dict = dict()
|
||||
|
||||
for key in keys:
|
||||
with open(os.path.join(folder, key + '.yml'), 'r') as f:
|
||||
value = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
with open(os.path.join(folder, key + '.bin'), 'rb') as f:
|
||||
data = np.frombuffer(f.read(), dtype=value['dtype'])
|
||||
|
||||
value['data'] = data.reshape((-1,) + value['dim'])
|
||||
|
||||
return_dict[key] = value
|
||||
|
||||
return return_dict
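# Round-trip usage sketch: write debug tensors during inference, then load them
# back for visualization (key name and sample rate are illustrative):
#
#   init('endoscopy_out')
#   for state in states:                 # e.g. once per generated sample
#       write_data('gru_a_state', state, fs=16000)
#   close()
#   data = read_data('endoscopy_out')
#   data['gru_a_state']['data'].shape    # (num_writes, *state.shape)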

def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)"""

    if len(shape) > 1:
        pixel_count = 1
        for s in shape:
            pixel_count *= s
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return (1,)

    num_columns = int((pixel_count / target_ratio)**.5)

    while (pixel_count % num_columns):
        num_columns -= 1

    num_rows = pixel_count // num_columns

    return (num_rows, num_columns)

def get_type_and_shape(shape):

    # can happen if data is one dimensional
    if len(shape) == 0:
        shape = (1,)

    # calculate pixel count
    if len(shape) > 1:
        pixel_count = 1
        for s in shape:
            pixel_count *= s
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already 2-dimensional
    if len(shape) == 2:
        if (shape[0] != pixel_count) or (shape[1] != pixel_count):
            return 'image', shape

    return 'image', get_best_reshape(shape)

def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):

    # determine plot setup
    num_keys = len(data.keys())

    # guard against num_rows == 0 for a single key
    num_rows = max(1, int((num_keys * 3/4) ** .5))

    num_cols = (num_keys + num_rows - 1) // num_rows

    # squeeze=False keeps axs 2-dimensional even for a single row or column
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    fs_max = max([val['fs'] for val in data.values()])

    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()

        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            feature_index = int(round(index * display[key]['down_factor']))

            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)
3
dnn/torch/lpcnet/utils/layers/__init__.py
Normal file

@@ -0,0 +1,3 @@
from .dual_fc import DualFC
from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding
15
dnn/torch/lpcnet/utils/layers/dual_fc.py
Normal file

@@ -0,0 +1,15 @@
import torch
from torch import nn

class DualFC(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DualFC, self).__init__()

        self.dense1 = nn.Linear(input_dim, output_dim)
        self.dense2 = nn.Linear(input_dim, output_dim)

        self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)

    def forward(self, x):
        return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x))
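# DualFC mixes two tanh-bounded linear projections with learned scalar weights.
# Quick shape check (illustrative):
#
#   layer = DualFC(16, 256)
#   layer(torch.randn(8, 16)).shape   # -> torch.Size([8, 256])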
42
dnn/torch/lpcnet/utils/layers/pcm_embeddings.py
Normal file

@@ -0,0 +1,42 @@
""" module implementing PCM embeddings for LPCNet """
|
||||
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PCMEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=128, num_levels=256):
|
||||
super(PCMEmbedding, self).__init__()
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.num_levels = num_levels
|
||||
|
||||
self.embedding = nn.Embedding(self.num_levels, self.num_dim)
|
||||
|
||||
# initialize
|
||||
with torch.no_grad():
|
||||
num_rows, num_cols = self.num_levels, self.embed_dim
|
||||
a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
|
||||
for i in range(num_rows):
|
||||
a[i, :] += m.sqrt(12) * (i - num_rows / 2)
|
||||
self.embedding.weight[:, :] = 0.1 * a
|
||||
|
||||
def forward(self, x):
|
||||
return self.embeddint(x)
|
||||
|
||||
|
||||
class DifferentiablePCMEmbedding(PCMEmbedding):
|
||||
def __init__(self, embed_dim, num_levels=256):
|
||||
super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)
|
||||
|
||||
def forward(self, x):
|
||||
x_int = (x - torch.floor(x)).detach().long()
|
||||
x_frac = x - x_int
|
||||
x_next = torch.minimum(x_int + 1, self.num_levels)
|
||||
|
||||
embed_0 = self.embedding(x_int)
|
||||
embed_1 = self.embedding(x_next)
|
||||
|
||||
return (1 - x_frac) * embed_0 + x_frac * embed_1
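# Sketch: fractional inputs interpolate linearly between neighbouring embedding
# rows (assuming x carries PCM levels as float indices):
#
#   emb = DifferentiablePCMEmbedding(8)
#   emb(torch.tensor([100.25]))   # = 0.75 * weight[100] + 0.25 * weight[101]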
468
dnn/torch/lpcnet/utils/layers/subconditioner.py
Normal file

@@ -0,0 +1,468 @@
import torch
from torch import nn


def get_subconditioner( method,
                        number_of_subsamples,
                        pcm_embedding_size,
                        state_size,
                        pcm_levels,
                        number_of_signals,
                        **kwargs):

    subconditioner_dict = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    return subconditioner_dict[method](number_of_subsamples,
        pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs)
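# Usage sketch (sizes are illustrative; 'modulative' is the method used by the
# multi-rate setup in utils/templates.py):
#
#   sub = get_subconditioner('modulative', number_of_subsamples=2,
#                            pcm_embedding_size=64, state_size=384,
#                            pcm_levels=256, number_of_signals=3)
#   c_states = sub(states, signals)   # list of number_of_subsamples conditioned states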

class Subconditioner(nn.Module):
    def __init__(self):
        """ upsampling by subconditioning

        Upsamples a sequence of states conditioning on pcm signals and
        optionally a feature vector.
        """
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        raise Exception("Base class should not be called")

    def single_step(self, index, state, signals, features):
        raise Exception("Base class should not be called")

    def get_output_dim(self, index):
        raise Exception("Base class should not be called")


class AdditiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition """

        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            raise ValueError('For additive subconditioning state and embedding '
                + f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')

        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """

        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.sum(embed, dim=2)

            new_states = new_states + embed
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops


class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        self.embeddings = []
        start_index = 0
        if self.recurrent:
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                new_states = torch.cat((states, embed), dim=-1)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0 and self.recurrent:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if not self.recurrent and index > 0:
                # overwrite previous conditioning vector
                c_state = torch.cat((state[..., :self.state_size], c), dim=-1)
            else:
                c_state = torch.cat((state, c), dim=-1)

        return c_state

    def get_average_flops_per_step(self):
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals

class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])


    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if self.state_recurrent:
                r_state = self.state_transform(state)
                c = torch.cat((c, r_state), dim=-1)
            alpha = torch.tanh(self.alphas[index](c))
            beta = torch.tanh(self.betas[index](c))
            c_state = torch.tanh((1 + alpha) * state + beta)

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # estimate activation by 10 flops
        # c_state = torch.tanh((1 + alpha) * state + beta)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B)
        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps
        flops *= (s - 1) / s

        return flops

class ComparitiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        """ subconditioning by comparison """

        super(ComparitiveSubconditioner, self).__init__()

        # store constructor arguments before deriving sizes from them
        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states
36
dnn/torch/lpcnet/utils/misc.py
Normal file

@@ -0,0 +1,36 @@
import torch


def find(a, v):
    try:
        idx = a.index(v)
    except ValueError:
        idx = -1
    return idx

def interleave_tensors(tensors, dim=-2):
    """ interleave list of tensors along sequence dimension """

    x = torch.cat([x.unsqueeze(dim) for x in tensors], dim=dim)
    x = torch.flatten(x, dim - 1, dim)

    return x

def _interleave(x, pcm_levels=256):

    repeats = pcm_levels // (2*x.size(-1))
    x = x.unsqueeze(-1)
    p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)

    return p

def get_pdf_from_tree(x):
    pcm_levels = x.size(-1)

    # pass pcm_levels explicitly so the expansion matches the size of x
    p = _interleave(x[..., 1:2], pcm_levels)
    n = 4
    while n <= pcm_levels:
        p = p * _interleave(x[..., n//2:n], pcm_levels)
        n *= 2

    return p
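# get_pdf_from_tree expands hierarchical (binary-tree) split probabilities into a
# flat pdf over pcm_levels bins: x[..., 1] holds the first-level split, x[..., 2:4]
# the second-level splits, and so on (x[..., 0] is unused). Quick check with
# uniform splits (illustrative):
#
#   x = torch.tensor([0., 0.5, 0.5, 0.5])   # 4 levels, every node splits 50/50
#   get_pdf_from_tree(x)                     # -> tensor([0.25, 0.25, 0.25, 0.25])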
6
dnn/torch/lpcnet/utils/pcm.py
Normal file

@@ -0,0 +1,6 @@

def clip_to_int16(x):
    int_min = -2**15
    int_max = 2**15 - 1
    x_clipped = max(int_min, min(x, int_max))
    return x_clipped
15
dnn/torch/lpcnet/utils/sample.py
Normal file

@@ -0,0 +1,15 @@
import torch


def sample_excitation(probs, pitch_corr):

    norm = lambda x : x / (x.sum() + 1e-18)

    # lowering the temperature
    probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))
    # cut-off tails
    probs = norm(torch.maximum(probs - 0.002, torch.FloatTensor([0])))
    # sample
    exc = torch.multinomial(probs.squeeze(), 1)

    return exc
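# The sharpening exponent grows with pitch correlation: for strongly voiced frames
# (pitch_corr near 1) the pdf is raised to ~2, suppressing noisy tails before the
# 0.002 cut-off. Illustrative call:
#
#   probs = torch.softmax(torch.randn(256), dim=0)
#   sample_excitation(probs, pitch_corr=0.8)   # -> LongTensor with one sampled level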
2
dnn/torch/lpcnet/utils/sparsification/__init__.py
Normal file

@@ -0,0 +1,2 @@
from .gru_sparsifier import GRUSparsifier
from .common import sparsify_matrix, calculate_gru_flops_per_step
92
dnn/torch/lpcnet/utils/sparsification/common.py
Normal file

@@ -0,0 +1,92 @@
import torch

def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list, keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

    Parameters:
    -----------
    matrix : torch.tensor
        matrix to sparsify
    density : float
        target density
    block_size : [int, int]
        block size dimensions
    keep_diagonal : bool
        If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
    return_mask : bool
        If true, the pruning mask is returned together with the sparsified matrix
    """

    m, n   = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # extract diagonal if keep_diagonal = True
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix   = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # calculate energy in sub-blocks
    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
    x = x ** 2
    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    number_of_survivors = round(number_of_blocks * density)

    # masking threshold
    if number_of_survivors == 0:
        threshold = 0
    else:
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]

    # create mask
    mask = torch.ones_like(block_energies)
    mask[block_energies < threshold] = 0
    mask = torch.repeat_interleave(mask, m1, dim=0)
    mask = torch.repeat_interleave(mask, n1, dim=1)

    # perform masking
    masked_matrix = mask * matrix + to_spare

    if return_mask:
        return masked_matrix, mask
    else:
        return masked_matrix
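# Minimal sketch: keep the highest-energy half of all 4x4 blocks of a 16x16 weight
#
#   w = torch.randn(16, 16)
#   w_sparse, mask = sparsify_matrix(w, density=0.5, block_size=[4, 4], return_mask=True)
#   mask.mean()   # -> tensor(0.5000), half of the entries survive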

def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
    input_size = gru.input_size
    hidden_size = gru.hidden_size
    flops = 0

    input_density = (
        sparsification_dict.get('W_ir', [1])[0]
        + sparsification_dict.get('W_in', [1])[0]
        + sparsification_dict.get('W_iz', [1])[0]
    ) / 3

    recurrent_density = (
        sparsification_dict.get('W_hr', [1])[0]
        + sparsification_dict.get('W_hn', [1])[0]
        + sparsification_dict.get('W_hz', [1])[0]
    ) / 3

    # input matrix vector multiplications
    if not drop_input:
        flops += 2 * 3 * input_size * hidden_size * input_density

    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density

    # biases
    flops += 6 * hidden_size

    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops
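# Usage sketch: per-step flops for a GRU whose recurrent weights are sparsified to
# the densities used for gru_a in utils/templates.py (input size is illustrative):
#
#   gru = torch.nn.GRU(512, 384, batch_first=True)
#   sp = {'W_hr': (0.05, [4, 8], True), 'W_hz': (0.05, [4, 8], True), 'W_hn': (0.2, [4, 8], True)}
#   calculate_gru_flops_per_step(gru, sp)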
158
dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
Normal file

@@ -0,0 +1,158 @@
import torch

from .common import sparsify_matrix


class GRUSparsifier:
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
            update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
            where density is the target density in [0, 1], [m, n] is the shape of the sub-blocks to which
            sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
            should be kept.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to GRUSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> gru = torch.nn.GRU(10, 20)
        >>> sparsify_dict = {
        ...         'W_ir' : (0.5, [2, 2], False),
        ...         'W_iz' : (0.6, [2, 2], False),
        ...         'W_in' : (0.7, [2, 2], False),
        ...         'W_hr' : (0.1, [4, 4], True),
        ...         'W_hz' : (0.2, [4, 4], True),
        ...         'W_hn' : (0.3, [4, 4], True),
        ...     }
        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
        >>> for i in range(100):
        ...         sparsifier.step()
        """
        # just copying parameters...
        self.start      = start
        self.stop       = stop
        self.interval   = interval
        self.exponent   = exponent
        self.task_list  = task_list

        # ... and setting counter to 0
        self.step_counter = 0

        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

    def step(self, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """
        # compute current interpolation factor
        self.step_counter += 1

        if self.step_counter < self.start:
            return
        elif self.step_counter < self.stop:
            # update only every self.interval-th step
            if self.step_counter % self.interval:
                return

            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
        else:
            alpha = 0


        with torch.no_grad():
            for gru, params in self.task_list:
                hidden_size = gru.hidden_size

                # input weights
                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
                    if key in params:
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
                            density, # density
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask

                # recurrent weights
                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
                    if key in params:
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")
                        gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
                            density,
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask


if __name__ == "__main__":
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)
    sparsify_dict = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)

    for i in range(100):
        sparsifier.step(verbose=True)
128
dnn/torch/lpcnet/utils/templates.py
Normal file

@@ -0,0 +1,128 @@
from models import multi_rate_lpcnet
import copy

setup_dict = dict()

dataset_template_v2 = {
    'version'              : 2,
    'feature_file'         : 'features.f32',
    'signal_file'          : 'data.s16',
    'frame_length'         : 160,
    'feature_frame_length' : 36,
    'signal_frame_length'  : 2,
    'feature_dtype'        : 'float32',
    'signal_dtype'         : 'int16',
    'feature_frame_layout' : {'cepstrum': [0, 18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
    'signal_frame_layout'  : {'last_signal': 0, 'signal': 1}  # signal, last_signal, error, prediction
}

dataset_template_v1 = {
    'version'              : 1,
    'feature_file'         : 'features.f32',
    'signal_file'          : 'data.u8',
    'frame_length'         : 160,
    'feature_frame_length' : 55,
    'signal_frame_length'  : 4,
    'feature_dtype'        : 'float32',
    'signal_dtype'         : 'uint8',
    'feature_frame_layout' : {'cepstrum': [0, 18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
    'signal_frame_layout'  : {'last_signal': 0, 'prediction': 1, 'last_error': 2, 'error': 3}  # signal, last_signal, error, prediction
}
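To make the v2 layout concrete, a hedged numpy sketch of how such files could be sliced — the file names match the template defaults, the variable names are illustrative, and the dataset loader in this commit remains the authoritative reader:

    import numpy as np

    # v2: float32 feature frames of length 36, int16 signal frames of length 2
    features = np.fromfile('features.f32', dtype='float32').reshape(-1, 36)
    signals  = np.fromfile('data.s16', dtype='int16').reshape(-1, 2)

    cepstrum   = features[:, 0:18]   # 'cepstrum': [0, 18]
    periods    = features[:, 18:19]  # 'periods': [18, 19]
    pitch_corr = features[:, 19:20]  # 'pitch_corr': [19, 20]
    lpc        = features[:, 20:36]  # 'lpc': [20, 36]

    last_signal = signals[:, 0]      # 'last_signal': 0
    signal      = signals[:, 1]      # 'signal': 1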

# lpcnet

lpcnet_config = {
    'frame_size'               : 160,
    'gru_a_units'              : 384,
    'gru_b_units'              : 64,
    'feature_conditioning_dim' : 128,
    'feature_conv_kernel_size' : 3,
    'period_levels'            : 257,
    'period_embedding_dim'     : 64,
    'signal_embedding_dim'     : 128,
    'signal_levels'            : 256,
    'feature_dimension'        : 19,
    'output_levels'            : 256,
    'lpc_gamma'                : 0.9,
    'features'                 : ['cepstrum', 'periods', 'pitch_corr'],
    'signals'                  : ['last_signal', 'prediction', 'last_error'],
    'input_layout'             : {'signals'  : {'last_signal': 0, 'prediction': 1, 'last_error': 2},
                                  'features' : {'cepstrum': [0, 18], 'pitch_corr': [18, 19]}},
    'target'                   : 'error',
    'feature_history'          : 2,
    'feature_lookahead'        : 2,
    'sparsification'           : {
        'gru_a' : {
            'start'    : 10000,
            'stop'     : 30000,
            'interval' : 100,
            'exponent' : 3,
            'params'   : {
                'W_hr' : (0.05, [4, 8], True),
                'W_hz' : (0.05, [4, 8], True),
                'W_hn' : (0.2,  [4, 8], True)
            },
        },
        'gru_b' : {
            'start'    : 10000,
            'stop'     : 30000,
            'interval' : 100,
            'exponent' : 3,
            'params'   : {
                'W_ir' : (0.5, [4, 8], False),
                'W_iz' : (0.5, [4, 8], False),
                'W_in' : (0.5, [4, 8], False)
            },
        }
    },
    'add_reference_phase'      : False,
    'reference_phase_dim'      : 0
}
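Tying the `sparsification` block back to the sparsifier above — a plausible wiring in the training loop, sketched under assumptions: `model.gru_a`/`model.gru_b` are hypothetical attribute names, and `GRUSparsifier` is called with the positional signature seen in the test block (task list, start, stop, interval; `exponent` may be a further argument):

    # illustrative wiring, not the exact code in this commit
    sparsifiers = []
    for name, gru in [('gru_a', model.gru_a), ('gru_b', model.gru_b)]:
        cfg = lpcnet_config['sparsification'][name]
        sparsifiers.append(
            GRUSparsifier([(gru, cfg['params'])], cfg['start'], cfg['stop'], cfg['interval'])
        )

    # called once per optimizer step during training
    for sparsifier in sparsifiers:
        sparsifier.step()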

# multi rate
subconditioning = {
    'subconditioning_a' : {
        'number_of_subsamples' : 2,
        'method'               : 'modulative',
        'signals'              : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size'   : 64,
        'kwargs'               : dict()
    },
    'subconditioning_b' : {
        'number_of_subsamples' : 2,
        'method'               : 'modulative',
        'signals'              : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size'   : 64,
        'kwargs'               : dict()
    }
}

# note: shallow copy, so nested dicts are shared with lpcnet_config
multi_rate_lpcnet_config = lpcnet_config.copy()
multi_rate_lpcnet_config['subconditioning'] = subconditioning

training_default = {
    'batch_size'        : 256,
    'epochs'            : 20,
    'lr'                : 1e-3,
    'lr_decay_factor'   : 2.5e-5,
    'adam_betas'        : [0.9, 0.99],
    'frames_per_sample' : 15
}

lpcnet_setup = {
    'dataset'  : '/local/datasets/lpcnet_training',
    'lpcnet'   : {'config': lpcnet_config, 'model': 'lpcnet'},
    'training' : training_default
}

multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'

setup_dict = {
    'lpcnet'     : lpcnet_setup,
    'multi_rate' : multi_rate_lpcnet_setup
}
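These templates are presumably what `make_default_setup.py` (see the README) serializes; a minimal sketch, assuming the script just picks a template, patches the dataset path, and dumps YAML:

    import yaml

    from utils.templates import setup_dict

    # hypothetical sketch: select a template, point it at a dataset, write YAML
    setup = setup_dict['lpcnet']
    setup['dataset'] = 'my_dataset_folder'

    with open('my_setup.yml', 'w') as f:
        yaml.dump(setup, f)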
29
dnn/torch/lpcnet/utils/ulaw.py
Normal file

@@ -0,0 +1,29 @@
import math as m

import torch


def ulaw2lin(u):
    scale_1 = 32768.0 / 255.0
    u = u - 128
    s = torch.sign(u)
    u = torch.abs(u)
    return s * scale_1 * (torch.exp(u / 128. * m.log(256)) - 1)


def lin2ulawq(x):
    scale = 255.0 / 32768.0
    s = torch.sign(x)
    x = torch.abs(x)
    u = s * (128 * torch.log(1 + scale * x) / m.log(256))
    u = torch.clip(128 + torch.round(u), 0, 255)
    return u


def lin2ulaw(x):
    scale = 255.0 / 32768.0
    s = torch.sign(x)
    x = torch.abs(x)
    # m.log(256) rather than torch.log(256): torch.log expects a tensor
    u = s * (128 * torch.log(1 + scale * x) / m.log(256))
    u = torch.clip(128 + u, 0, 255)
    return u
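These are the standard μ-law companding curves with μ = 255 (so 1 + μ = 256, hence the `log(256)` normalizer), offset so that codes land in [0, 255]. A small round-trip sanity check, with arbitrary test values:

    import torch

    from utils.ulaw import lin2ulawq, ulaw2lin

    x = torch.tensor([-32768.0, -1000.0, 0.0, 1000.0, 32767.0])
    u = lin2ulawq(x)     # quantized 8-bit codes in [0, 255]
    x_hat = ulaw2lin(u)  # back to the 16-bit range
    # error grows with amplitude, as expected for mu-law companding
    print(torch.abs(x - x_hat))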
14
dnn/torch/lpcnet/utils/wav.py
Normal file

@@ -0,0 +1,14 @@
import wave

def wavwrite16(filename, x, fs):
    """ writes x as int16 to file with name filename

    If x.dtype is int16 x is written as is. Otherwise,
    it is scaled by 2**15 - 1 and converted to int16.
    """
    if x.dtype != 'int16':
        x = ((2**15 - 1) * x).astype('int16')

    with wave.open(filename, 'wb') as f:
        f.setparams((1, 2, fs, len(x), 'NONE', ""))
        f.writeframes(x.tobytes())
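A quick usage sketch (tone and sample rate chosen arbitrarily) — a float signal in [-1, 1] is scaled to int16 by the helper:

    import numpy as np

    from utils.wav import wavwrite16

    fs = 16000
    t = np.arange(fs) / fs                  # one second
    x = 0.5 * np.sin(2 * np.pi * 440 * t)   # float64, scaled to int16 by wavwrite16
    wavwrite16('tone.wav', x, fs)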