opus/dnn/torch/lpcnet/utils/layers/subconditioner.py
Jan Buethe 7b8ba143f1
added copyright headers
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
2023-09-05 22:31:19 +02:00

497 lines
17 KiB
Python

"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from re import sub
import torch
from torch import nn
def get_subconditioner(method,
                       number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs):
    """ Factory for subconditioner modules.

    Parameters:
    -----------
    method : str
        one of 'additive', 'concatenative' or 'modulative'
    number_of_subsamples : int
        upsampling factor s
    pcm_embedding_size : int
        size of the pcm embedding vectors
    state_size : int
        size of the state to be subconditioned
    pcm_levels : int
        number of discrete pcm levels (embedding table size)
    number_of_signals : int
        number of pcm signals per time step
    kwargs
        forwarded to the selected subconditioner constructor

    Returns:
    --------
    Subconditioner
        an instance of the selected subconditioner class

    Raises:
    -------
    ValueError
        if method does not name a known subconditioning method
    """
    subconditioner_dict = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    try:
        cls = subconditioner_dict[method]
    except KeyError:
        # raise a descriptive error instead of a bare KeyError
        raise ValueError(f"unknown subconditioning method '{method}', valid "
                         + f"methods are {sorted(subconditioner_dict.keys())}") from None

    return cls(number_of_subsamples, pcm_embedding_size, state_size,
               pcm_levels, number_of_signals, **kwargs)
class Subconditioner(nn.Module):
    """ Abstract base class for upsampling by subconditioning.

    Upsamples a sequence of states conditioning on pcm signals and
    optionally a feature vector. Subclasses must implement forward,
    single_step and get_output_dim.
    """
    def __init__(self):
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        """ Batch subconditioning; must be implemented by subclasses. """
        # NotImplementedError (instead of generic Exception) marks this as abstract
        raise NotImplementedError("subclasses of Subconditioner must implement forward")

    def single_step(self, index, state, signals, features):
        """ Single-step inference; must be implemented by subclasses. """
        raise NotImplementedError("subclasses of Subconditioner must implement single_step")

    def get_output_dim(self, index):
        """ Output state size at subsample position index; must be implemented by subclasses. """
        raise NotImplementedError("subclasses of Subconditioner must implement get_output_dim")
class AdditiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition

        Embeds the pcm signals at each subsample position and adds the
        signal-summed embedding to the running state. Requires
        pcm_embedding_size == state_size so the addition is well defined.

        Raises:
        -------
        ValueError
            if pcm_embedding_size != state_size
        """
        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            # fixed duplicated word ("but but") in the original message
            raise ValueError('For additive subconditioning state and embedding '
                + f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')

        # index 0 has no embedding: the first subsample re-uses the state as is
        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            # take every s-th signal starting at offset i
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension by summing over the signals axis
            embed = torch.sum(embed, dim=2)
            new_states = new_states + embed
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0:
            # first subsample position passes the state through unchanged
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        # addition never changes the state width
        return self.state_size

    def get_average_flops_per_step(self):
        # one embedding add per signal at s - 1 of s positions
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops
class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation

        Embeds the pcm signals at each subsample position and concatenates
        the flattened embedding to the state. In recurrent mode embeddings
        accumulate across positions; otherwise each position concatenates to
        the original state.
        """
        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        self.embeddings = []
        start_index = 0
        if self.recurrent:
            # recurrent mode: position 0 passes the state through, no embedding
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples
        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension by flattening the last two axes
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                # embeddings accumulate: state grows at every position
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                # always concatenate to the original state
                new_states = torch.cat((states, embed), dim=-1)
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0 and self.recurrent:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if not self.recurrent and index > 0:
                # overwrite previous conditioning vector
                c_state = torch.cat((state[..., :self.state_size], c), dim=-1)
            else:
                c_state = torch.cat((state, c), dim=-1)

        # removed a duplicated, unreachable `return c_state` after this one
        return c_state

    def get_average_flops_per_step(self):
        # concatenation itself costs no multiply-adds
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            # state grows by one flattened embedding per position
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals
class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation

        Embeds the pcm signals at each subsample position and derives
        per-position scale (alpha) and shift (beta) vectors that modulate
        the running state. With state_recurrent=True a compressed copy of
        the previous state is appended to the modulation input.
        """
        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        # input size of the alpha/beta dense layers
        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # index 0 has no modulation: the first subsample re-uses the state as is
        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension by flattening the last two axes
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if self.state_recurrent:
                r_state = self.state_transform(state)
                c = torch.cat((c, r_state), dim=-1)
            alpha = torch.tanh(self.alphas[index](c))
            beta = torch.tanh(self.betas[index](c))
            c_state = torch.tanh((1 + alpha) * state + beta)

        # removed a duplicated, unreachable `return c_state` after this one
        return c_state

    def get_output_dim(self, index):
        # modulation never changes the state width
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # estimate activation by 10 flops
        # c_state = torch.tanh((1 + alpha) * state + beta)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B)
        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps: index 0 does no work
        flops *= (s - 1) / s

        return flops
class ComparitiveSubconditioner(Subconditioner):
def __init__(self,
             number_of_subsamples,
             pcm_embedding_size,
             state_size,
             pcm_levels,
             number_of_signals,
             error_index=-1,
             apply_gate=True,
             normalize=False):
    """ subconditioning by comparison

    Parameters:
    -----------
    number_of_subsamples : int
        upsampling factor s
    pcm_embedding_size : int
        size of the pcm embedding vectors (also used as comparison size)
    state_size : int
        size of the state to be subconditioned
    pcm_levels : int
        number of discrete pcm levels (embedding table size)
    number_of_signals : int
        number of pcm signals per time step
    error_index : int
        position of the error signal, stored as error_position
    apply_gate : bool
        whether to create the gating dense layer
    normalize : bool
        normalization flag, stored for later use
    """
    super(ComparitiveSubconditioner, self).__init__()

    # store constructor arguments first: the original read attributes like
    # self.pcm_embedding_size and self.state_size before ever assigning
    # them, which raised AttributeError on construction
    self.number_of_subsamples = number_of_subsamples
    self.pcm_embedding_size = pcm_embedding_size
    self.state_size = state_size
    self.pcm_levels = pcm_levels
    self.number_of_signals = number_of_signals

    self.comparison_size = self.pcm_embedding_size
    self.error_position = error_index
    self.apply_gate = apply_gate
    self.normalize = normalize

    self.state_transform = nn.Linear(self.state_size, self.comparison_size)

    # fixed typo: the original used the undefined name self.number_of_signales
    self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
    self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

    if self.apply_gate:
        self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

    # embeddings and state transforms; index 0 has no embedding/denses
    self.embeddings = [None]
    self.alpha_denses = [None]
    self.beta_denses = [None]
    self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
    self.add_module('state_transform_0', self.state_transforms[0])

    for i in range(1, self.number_of_subsamples):
        embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
        self.add_module('pcm_embedding_' + str(i), embedding)
        self.embeddings.append(embedding)

        state_transform = nn.Linear(self.state_size, self.comparison_size)
        self.add_module('state_transform_' + str(i), state_transform)
        self.state_transforms.append(state_transform)

        self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
        self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])
        self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
        self.add_module('beta_dense_' + str(i), self.beta_denses[-1])
def forward(self, states, signals):
    """ creates list of subconditioned states

    Parameters:
    -----------
    states : torch.tensor
        presumably states of shape (batch, seq_length // s, state_size),
        matching the other subconditioners -- TODO confirm
    signals : torch.tensor
        presumably signals of shape (batch, seq_length, number_of_signals)
        -- TODO confirm

    Returns:
    --------
    c_states : list of torch.tensor
        list of s subconditioned states
    """
    s = self.number_of_subsamples

    c_states = [states]
    new_states = states
    for i in range(1, self.number_of_subsamples):
        embed = self.embeddings[i](signals[:, i::s])
        # reduce signal dimension
        embed = torch.flatten(embed, -2)

        # compress previous state to comparison_size before modulating it
        comp_states = self.state_transforms[i](new_states)

        # NOTE(review): the per-step self.alpha_denses[i]/self.beta_denses[i]
        # created in __init__ are unused here -- the shared self.alpha_dense /
        # self.beta_dense are applied at every position instead. The gate
        # (self.gate_dense) and self.normalize are also unused. This class
        # looks unfinished; confirm intended behavior before relying on it.
        alpha = torch.tanh(self.alpha_dense(embed))
        beta = torch.tanh(self.beta_dense(embed))

        # new state obtained by modulating previous state
        # NOTE(review): alpha/beta have state_size features while comp_states
        # has comparison_size (= pcm_embedding_size); unless those match, this
        # broadcasted multiply would fail -- verify against callers.
        new_states = torch.tanh((1 + alpha) * comp_states + beta)

        c_states.append(new_states)

    return c_states