mirror of
https://github.com/xiph/opus.git
synced 2025-06-04 01:27:42 +00:00
added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
This commit is contained in:
parent
90a171c1c2
commit
35ee397e06
38 changed files with 3200 additions and 0 deletions
3
dnn/torch/lpcnet/utils/layers/__init__.py
Normal file
3
dnn/torch/lpcnet/utils/layers/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .dual_fc import DualFC
|
||||
from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
|
||||
from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding
|
15
dnn/torch/lpcnet/utils/layers/dual_fc.py
Normal file
15
dnn/torch/lpcnet/utils/layers/dual_fc.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
class DualFC(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(DualFC, self).__init__()
|
||||
|
||||
self.dense1 = nn.Linear(input_dim, output_dim)
|
||||
self.dense2 = nn.Linear(input_dim, output_dim)
|
||||
|
||||
self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
|
||||
self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x))
|
42
dnn/torch/lpcnet/utils/layers/pcm_embeddings.py
Normal file
42
dnn/torch/lpcnet/utils/layers/pcm_embeddings.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
""" module implementing PCM embeddings for LPCNet """
|
||||
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PCMEmbedding(nn.Module):
    """Embedding layer for PCM levels.

    Maps integer PCM levels in [0, num_levels) to embed_dim-dimensional
    vectors. Weights are initialized as uniform noise plus a per-row
    offset that grows linearly with the level, so embeddings vary
    smoothly with the PCM value.
    """

    def __init__(self, embed_dim=128, num_levels=256):
        super(PCMEmbedding, self).__init__()

        self.embed_dim = embed_dim
        self.num_levels = num_levels

        # bug fix: original referenced non-existent attribute self.num_dim
        self.embedding = nn.Embedding(self.num_levels, self.embed_dim)

        # initialize: uniform noise with variance 1 per entry, offset by
        # a ramp that is linear in the PCM level, scaled down by 0.1
        with torch.no_grad():
            num_rows, num_cols = self.num_levels, self.embed_dim
            a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
            for i in range(num_rows):
                a[i, :] += m.sqrt(12) * (i - num_rows / 2)
            self.embedding.weight[:, :] = 0.1 * a

    def forward(self, x):
        # bug fix: original called the misspelled self.embeddint
        return self.embedding(x)
|
||||
|
||||
|
||||
class DifferentiablePCMEmbedding(PCMEmbedding):
    """PCM embedding with linear interpolation between adjacent levels.

    Accepts real-valued PCM positions and returns the linear
    interpolation between the embeddings of the two neighboring integer
    levels, keeping the output differentiable w.r.t. x.
    """

    def __init__(self, embed_dim, num_levels=256):
        super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)

    def forward(self, x):
        # bug fix: original computed (x - floor(x)).long(), which truncates
        # the fractional part and is always 0; the lower neighboring level
        # is floor(x)
        x_int = torch.floor(x).detach().long()
        x_frac = x - x_int
        # bug fix: original passed a python int to torch.minimum and
        # allowed the out-of-range index num_levels; clamp to the last
        # valid row num_levels - 1
        x_next = torch.clamp(x_int + 1, max=self.num_levels - 1)

        embed_0 = self.embedding(x_int)
        embed_1 = self.embedding(x_next)

        # interpolation weight needs a trailing singleton dim to broadcast
        # against the (..., embed_dim) embedding tensors
        w = x_frac.unsqueeze(-1)
        return (1 - w) * embed_0 + w * embed_1
|
468
dnn/torch/lpcnet/utils/layers/subconditioner.py
Normal file
468
dnn/torch/lpcnet/utils/layers/subconditioner.py
Normal file
|
@ -0,0 +1,468 @@
|
|||
from re import sub
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
||||
|
||||
def get_subconditioner( method,
                        number_of_subsamples,
                        pcm_embedding_size,
                        state_size,
                        pcm_levels,
                        number_of_signals,
                        **kwargs):
    """Factory for subconditioner modules.

    Looks up the class registered under `method` ('additive',
    'concatenative' or 'modulative') and instantiates it with the given
    configuration; raises KeyError for an unknown method.
    """
    constructors = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    constructor = constructors[method]

    return constructor(number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs)
|
||||
|
||||
|
||||
class Subconditioner(nn.Module):
    """ upsampling by subconditioning

    Upsamples a sequence of states conditioning on pcm signals and
    optionally a feature vector.

    Abstract base class: subclasses implement forward, single_step and
    get_output_dim.
    """
    def __init__(self):
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        # idiom fix: NotImplementedError is the conventional way to mark an
        # abstract method (it is still an Exception subclass, so any caller
        # catching Exception keeps working)
        raise NotImplementedError("Base class should not be called")

    def single_step(self, index, state, signals, features):
        raise NotImplementedError("Base class should not be called")

    def get_output_dim(self, index):
        raise NotImplementedError("Base class should not be called")
|
||||
|
||||
|
||||
class AdditiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition

        Adds summed pcm-signal embeddings onto the state at every
        subsample position; requires pcm_embedding_size == state_size.
        """

        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            # bug fix: removed duplicated "but but" in the error message
            raise ValueError('For additive subconditioning state and embedding '
                + f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')

        # index 0 is unused: the first subsample keeps the raw state
        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """

        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            # pick every s-th signal starting at offset i
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension by summing over the signals axis
            embed = torch.sum(embed, dim=2)

            # accumulate: each subsample adds onto the previous state
            new_states = new_states + embed
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0:
            # first subsample keeps the raw state
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        # additive conditioning never changes the state width
        return self.state_size

    def get_average_flops_per_step(self):
        # (s - 1) of every s steps perform an embedding sum of size
        # number_of_signals * pcm_embedding_size
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops
|
||||
|
||||
|
||||
class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation

        Appends flattened pcm-signal embeddings to the state. In
        recurrent mode conditioning accumulates across subsamples; in
        non-recurrent mode each subsample concatenates onto the raw state.
        """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        self.embeddings = []
        start_index = 0
        if self.recurrent:
            # in recurrent mode subsample 0 keeps the raw state, so no
            # embedding is needed at index 0
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension: flatten signal and embedding axes
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                # conditioning grows with every subsample
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                # each subsample concatenates onto the raw state
                new_states = torch.cat((states, embed), dim=-1)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        # control-flow cleanup: the original returned from inside the else
        # branch and again after the if/else; rewritten with guard-style
        # returns, behavior unchanged
        if index == 0 and self.recurrent:
            return state

        embed_signals = self.embeddings[index](signals)
        c = torch.flatten(embed_signals, -2)
        if not self.recurrent and index > 0:
            # overwrite previous conditioning vector
            return torch.cat((state[..., :self.state_size], c), dim=-1)
        return torch.cat((state, c), dim=-1)

    def get_average_flops_per_step(self):
        # embedding lookup and concatenation involve no multiply-adds
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            # conditioning accumulates, so width grows with the index
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals
|
||||
|
||||
class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation

        Modulates the state at every subsample position with scale and
        shift vectors derived from pcm-signal embeddings (and, when
        state_recurrent is set, also from a linear transform of the
        current state).
        """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            # the compressed state is appended to the embedding vector
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # index 0 is unused: the first subsample keeps the raw state
        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension: flatten signal and embedding axes
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        # control-flow cleanup: the original returned from inside the else
        # branch and again after the if/else; rewritten with guard-style
        # returns, behavior unchanged
        if index == 0:
            return state

        embed_signals = self.embeddings[index](signals)
        c = torch.flatten(embed_signals, -2)
        if self.state_recurrent:
            r_state = self.state_transform(state)
            c = torch.cat((c, r_state), dim=-1)
        alpha = torch.tanh(self.alphas[index](c))
        beta = torch.tanh(self.betas[index](c))
        return torch.tanh((1 + alpha) * state + beta)

    def get_output_dim(self, index):
        # modulation never changes the state width
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # estimate activation by 10 flops
        # c_state = torch.tanh((1 + alpha) * state + beta)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B)
        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps: subsample 0 does no work
        flops *= (s - 1) / s

        return flops
|
||||
|
||||
class ComparitiveSubconditioner(Subconditioner):
|
||||
def __init__(self,
|
||||
number_of_subsamples,
|
||||
pcm_embedding_size,
|
||||
state_size,
|
||||
pcm_levels,
|
||||
number_of_signals,
|
||||
error_index=-1,
|
||||
apply_gate=True,
|
||||
normalize=False):
|
||||
""" subconditioning by comparison """
|
||||
|
||||
super(ComparitiveSubconditioner, self).__init__()
|
||||
|
||||
self.comparison_size = self.pcm_embedding_size
|
||||
self.error_position = error_index
|
||||
self.apply_gate = apply_gate
|
||||
self.normalize = normalize
|
||||
|
||||
self.state_transform = nn.Linear(self.state_size, self.comparison_size)
|
||||
|
||||
self.alpha_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
|
||||
self.beta_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
|
||||
|
||||
if self.apply_gate:
|
||||
self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)
|
||||
|
||||
# embeddings and state transforms
|
||||
self.embeddings = [None]
|
||||
self.alpha_denses = [None]
|
||||
self.beta_denses = [None]
|
||||
self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
|
||||
self.add_module('state_transform_0', self.state_transforms[0])
|
||||
|
||||
for i in range(1, self.number_of_subsamples):
|
||||
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
|
||||
self.add_module('pcm_embedding_' + str(i), embedding)
|
||||
self.embeddings.append(embedding)
|
||||
|
||||
state_transform = nn.Linear(self.state_size, self.comparison_size)
|
||||
self.add_module('state_transform_' + str(i), state_transform)
|
||||
self.state_transforms.append(state_transform)
|
||||
|
||||
self.alpha_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
|
||||
self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])
|
||||
|
||||
self.beta_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
|
||||
self.add_module('beta_dense_' + str(i), self.beta_denses[-1])
|
||||
|
||||
    def forward(self, states, signals):
        """ creates list of subconditioned states

        Mirrors ModulativeSubconditioner.forward, but modulates a
        transformed (comparison-size) state rather than the state itself.

        NOTE(review): this uses the shared self.alpha_dense/self.beta_dense
        layers for every subsample, while __init__ also builds per-step
        self.alpha_denses[i]/self.beta_denses[i] that are never used here —
        confirm whether the per-step layers were intended.
        """
        s = self.number_of_subsamples

        # index 0 keeps the raw, unconditioned states
        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            # every s-th signal starting at offset i
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            # project state to comparison size before modulation
            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states
|
Loading…
Add table
Add a link
Reference in a new issue