mirror of
https://github.com/xiph/opus.git
synced 2025-05-17 00:48:29 +00:00
408 lines
18 KiB
Python
408 lines
18 KiB
Python
import torch
|
|
from torch import nn
|
|
from utils.layers.subconditioner import get_subconditioner
|
|
from utils.layers import DualFC
|
|
|
|
from utils.ulaw import lin2ulawq, ulaw2lin
|
|
from utils.sample import sample_excitation
|
|
from utils.pcm import clip_to_int16
|
|
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
|
|
|
|
from utils.misc import interleave_tensors
|
|
|
|
|
|
|
|
|
|
# MultiRateLPCNet
|
|
class MultiRateLPCNet(nn.Module):
    """LPCNet variant whose two GRUs run at reduced, different rates.

    A frame rate network turns acoustic features and period indices into one
    conditioning vector per frame.  GRU A consumes groups of
    ``substeps_a * substeps_b`` samples, subconditioner A expands every GRU A
    output into ``substeps_a`` inputs for GRU B, and subconditioner B expands
    every GRU B output into ``substeps_b`` inputs for the per-substep dual FC
    output layers.

    Args:
        config: nested mapping with the keys read in ``__init__`` (general,
            frame rate network, sample rate network, ``subconditioning`` and
            ``sparsification`` sections).
    """

    def __init__(self, config):
        super().__init__()

        # general parameters
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']
        self.signals = config['signals']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']

        # frame rate network layers
        # NOTE: layers are created in the same order as before so that
        # parameter initialisation order and state_dict keys are unchanged.
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim,
                                       self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim,
                                       self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim]))

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']

        # subconditioning B
        sub_config = config['subconditioning']['subconditioning_b']
        self.substeps_b = sub_config['number_of_subsamples']
        self.subcondition_signals_b = sub_config['signals']
        self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_b_units
        # pass the normalised kwargs: unpacking sub_config['kwargs'] directly
        # raises a TypeError when the config entry is None
        self.subconditioner_b = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, len(sub_config['signals']),
            **kwargs)

        # subconditioning A
        sub_config = config['subconditioning']['subconditioning_a']
        self.substeps_a = sub_config['number_of_subsamples']
        self.subcondition_signals_a = sub_config['signals']
        self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_a_units
        # subconditioner A sees substeps_b grouped samples per signal
        self.subconditioner_a = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
            **kwargs)

        # wrap up subconditioning, group_size_gru_a holds the number
        # of timesteps that are grouped as sample input for GRU A
        # input and group_size_subcondition_a holds the number of samples that are
        # grouped as input to pre-GRU B subconditioning
        self.group_size_gru_a = self.substeps_a * self.substeps_b
        self.group_size_subcondition_a = self.substeps_b
        self.gru_a_rate_divider = self.group_size_gru_a
        self.gru_b_rate_divider = self.substeps_b

        # gru sizes
        self.gru_a_input_dim = (self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim
                                + self.feature_conditioning_dim)
        self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
        self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)

        # sparsification
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            self._add_gru_sparsifier(self.gru_a, gru_config)
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            self._add_gru_sparsifier(self.gru_b, gru_config)
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # dual FCs, one per substep of subconditioner B; they are kept in a
        # plain list and registered by name so that the parameter names stay
        # "dual_fc_<i>"
        self.dual_fc = []
        for i in range(self.substeps_b):
            dim = self.subconditioner_b.get_output_dim(i)
            self.dual_fc.append(DualFC(dim, self.output_levels))
            self.add_module(f"dual_fc_{i}", self.dual_fc[-1])

    def _add_gru_sparsifier(self, gru, gru_config):
        """Appends a GRUSparsifier for ``gru`` built from ``gru_config``."""
        task_list = [(gru, gru_config['params'])]
        self.sparsifier.append(GRUSparsifier(task_list,
                                             gru_config['start'],
                                             gru_config['stop'],
                                             gru_config['interval'],
                                             gru_config['exponent']))

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        """Estimates the model complexity in GFLOPS.

        Args:
            fs: sampling rate used to convert per-step costs to per-second costs.
            verbose: if True, prints the contribution of every component.
            hierarchical_sampling: if True, the dual FC cost is divided by 8.

        Returns:
            Sum of the component estimates (frame rate network, both GRUs,
            both subconditioners and all dual FC layers).
        """
        gflops = 0

        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = (1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels)
                                         * conditioning_dim * frame_rate)
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs / self.group_size_gru_a
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # subconditioning a
        subcond_a_rate = fs / self.substeps_b
        subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
        if verbose:
            print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
        gflops += subconditioning_a_complexity

        # gru b
        gru_b_rate = fs / self.substeps_b
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # subconditioning b
        subcond_b_rate = fs
        subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
        if verbose:
            print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
        gflops += subconditioning_b_complexity

        # dual fcs: every layer handles one out of len(self.dual_fc) samples
        if self.dual_fc:
            rate = fs / len(self.dual_fc)
        for i, fc in enumerate(self.dual_fc):
            input_size = fc.dense1.in_features
            output_size = fc.dense1.out_features
            dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
            if hierarchical_sampling:
                dual_fc_complexity /= 8
            if verbose:
                print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
            gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def sparsify(self):
        """Advances every registered sparsifier by one step."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def frame_rate_network(self, features, periods):
        """Computes the frame-wise conditioning vectors.

        Args:
            features: float tensor, channels last (batch, frames, feature_dimension).
            periods: integer tensor of period indices; its embedding is
                flattened over dims 2..3, so it is expected to be 3-dimensional.

        Returns:
            Tensor with feature_conditioning_dim channels in the last
            dimension.  Both convolutions use 'valid' padding, so the frame
            axis is shortened by 2 * (feature_conv_kernel_size - 1).
        """
        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def prepare_signals(self, signals, group_size, signal_idx):
        """ extracts, delays and groups signals

        Args:
            signals: tensor of shape (batch, sequence_length, num_signals).
            group_size: number of consecutive time steps merged into one step;
                sequence_length must be a multiple of it.
            signal_idx: positions (last dimension) of the signals to keep.

        Returns:
            Tensor of shape (batch, sequence_length // group_size,
            group_size * len(signal_idx)).
        """
        batch_size, sequence_length, _ = signals.shape

        # extract signals according to position
        signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
                            dim=-1)

        # roll back pcm to account for grouping (torch.roll is circular, so
        # the last group_size - 1 steps wrap around to the front)
        signals = torch.roll(signals, group_size - 1, -2)

        # reshape
        signals = torch.reshape(signals,
            (batch_size, sequence_length // group_size, group_size * len(signal_idx)))

        return signals

    def sample_rate_network(self, signals, c, gru_states):
        """Runs the sample rate network on whole sequences (training path).

        Args:
            signals: integer signal tensor (batch, sequence_length, num_signals).
            c: conditioning vectors from ``frame_rate_network``.
            gru_states: pair (state for GRU A, state for GRU B).

        Returns:
            Tuple of the interleaved dual FC outputs and the pair of final
            GRU states.
        """
        signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
        embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
        # features at GRU A rate
        c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
        # features at GRU B rate
        c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)

        y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # first round of upsampling and subconditioning
        c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
        y = self.subconditioner_a(y, c_signals_a)
        y = interleave_tensors(y)

        y = torch.concat((y, c_upsampled_b), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
        y = self.subconditioner_b(y, c_signals_b)

        # one output layer per substep, outputs are merged back in time order
        y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
        y = interleave_tensors(y)

        return y, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        """Single-rate decoding step without subconditioning.

        NOTE(review): ``self.dual_fc`` is a Python list of per-substep DualFC
        modules (see ``__init__``), so the call ``self.dual_fc(y)`` below
        raises a TypeError.  This method looks like an unused leftover from a
        single-rate model; confirm before relying on or removing it.
        """
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):
        """Returns log-probabilities over the output levels for every sample."""
        c = self.frame_rate_network(features, periods)
        y, _ = self.sample_rate_network(signals, c, gru_states)
        log_probs = torch.log_softmax(y, dim=-1)

        return log_probs

    def generate(self, features, periods, lpcs):
        """Synthesises a signal sample by sample from features.

        Args:
            features: per-frame feature tensor (frames, feature_dimension),
                including feature_history and feature_lookahead frames.
            periods: per-frame period indices matching ``features``.
            lpcs: per-frame LPC coefficients; the last dimension is the LPC order.

        Returns:
            int16 CPU tensor with num_frames * frame_size output samples.
        """
        with torch.no_grad():
            device = next(self.parameters()).device

            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            num_input_signals = len(self.signals)
            # stride of one time slot in the subconditioner A input buffer
            num_cond_signals_a = len(self.subcondition_signals_a)
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers; they interact with tensors living on the model
            # device (LPCs, embeddings), so they are allocated there as well
            buffer_length = num_frames * self.frame_size + lpc_order + 1
            last_signal = torch.zeros(buffer_length, device=device)
            prediction = torch.zeros(buffer_length, device=device)
            last_error = torch.zeros(buffer_length, device=device)
            # the output is filled from Python numbers and stays on the CPU
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units), device=device)
            gru_b_state = torch.zeros((1, 1, self.gru_b_units), device=device)

            # all index buffers start at level 128
            input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals,
                                              dtype=torch.long, device=device)
            # conditioning signals for subconditioner a
            c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a),
                                            dtype=torch.long, device=device)
            # conditioning signals for subconditioner b
            c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long, device=device)

            # signal dict
            signal_dict = {
                'prediction': prediction,
                'last_error': last_error,
                'last_signal': last_signal
            }

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                # negated, time-reversed LPCs so that the prediction is a
                # plain dot product with the most recent lpc_order samples
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(0, self.frame_size, self.group_size_gru_a):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # calculate newest prediction
                    prediction[pcm_position] = torch.sum(
                        last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

                    # prepare input
                    for slot in range(self.group_size_gru_a):
                        k = slot - self.group_size_gru_a + 1
                        for idx, name in enumerate(self.signals):
                            input_signals[idx + slot * num_input_signals] = lin2ulawq(
                                signal_dict[name][pcm_position + k]
                            )

                    # run GRU A
                    embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
                    embed_signals = torch.flatten(embed_signals, 2)
                    y = torch.cat((embed_signals, current_c), dim=-1)
                    h_a, gru_a_state = self.gru_a(y, gru_a_state)

                    # loop over substeps_a
                    for step_a in range(self.substeps_a):
                        # prepare conditioning input; the buffer holds
                        # len(self.signals_idx_a) entries per slot, so that is
                        # the stride (not the number of GRU A input signals)
                        for slot in range(self.group_size_subcondition_a):
                            k = slot - self.group_size_subcondition_a + 1
                            for idx, name in enumerate(self.subcondition_signals_a):
                                c_signals_a[idx + slot * num_cond_signals_a] = lin2ulawq(
                                    signal_dict[name][pcm_position + k]
                                )

                        # subconditioning
                        h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))

                        # run GRU B
                        y = torch.cat((h_a, current_c), dim=-1)
                        h_b, gru_b_state = self.gru_b(y, gru_b_state)

                        # loop over substeps b
                        for step_b in range(self.substeps_b):
                            # prepare subconditioning input
                            for idx, name in enumerate(self.subcondition_signals_b):
                                c_signals_b[idx] = lin2ulawq(
                                    signal_dict[name][pcm_position]
                                )

                            # subcondition
                            h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))

                            # run dual FC
                            probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)

                            # sample
                            new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))

                            # update signals
                            sig = new_exc + prediction[pcm_position]
                            last_error[pcm_position + 1] = new_exc
                            last_signal[pcm_position + 1] = sig

                            # one-pole recursive filter on the output signal
                            mem = 0.85 * mem + float(sig)
                            output[output_position] = clip_to_int16(round(mem))

                            # increase positions
                            pcm_position += 1
                            output_position += 1

                            # calculate next prediction
                            prediction[pcm_position] = torch.sum(
                                last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

        return output
|